aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/cpu
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/cpu')
-rw-r--r--src/runtime/cpu/operators/CpuDepthwiseConv2d.cpp16
-rw-r--r--src/runtime/cpu/operators/CpuDepthwiseConv2d.h5
-rw-r--r--src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp520
-rw-r--r--src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.h27
-rw-r--r--src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp37
5 files changed, 66 insertions, 539 deletions
diff --git a/src/runtime/cpu/operators/CpuDepthwiseConv2d.cpp b/src/runtime/cpu/operators/CpuDepthwiseConv2d.cpp
index 160a9fd70b..f577e94def 100644
--- a/src/runtime/cpu/operators/CpuDepthwiseConv2d.cpp
+++ b/src/runtime/cpu/operators/CpuDepthwiseConv2d.cpp
@@ -62,8 +62,8 @@ Status validate_arguments_optimized(const ITensorInfo *src, const ITensorInfo *w
ARM_COMPUTE_RETURN_ON_ERROR(CpuDepthwiseConv2dAssemblyDispatch::validate(src, weights, biases, dst, info));
- //Validate Activation Layer
- if(info.act_info.enabled())
+ // Validate Activation Layer
+ if(info.act_info.enabled() && !CpuDepthwiseConv2dAssemblyDispatch::is_activation_supported(info.act_info))
{
ARM_COMPUTE_RETURN_ON_ERROR(CpuActivation::validate(dst, nullptr, info.act_info));
}
@@ -95,15 +95,7 @@ void CpuDepthwiseConv2d::CpuDepthwiseConv2dOptimizedInternal::configure(ITensorI
_is_prepared = false;
// Configure pipeline
- ActivationLayerInfo act_info_to_use = ActivationLayerInfo();
- const bool is_relu = arm_compute::utils::info_helpers::is_relu(info.act_info);
- const bool is_relu6 = arm_compute::utils::info_helpers::is_relu6(info.act_info);
- _is_activationlayer_enabled = info.act_info.enabled() && !(is_relu || is_relu6);
-
- if(!_is_activationlayer_enabled)
- {
- act_info_to_use = info.act_info;
- }
+ _is_activationlayer_enabled = info.act_info.enabled() && !CpuDepthwiseConv2dAssemblyDispatch::is_activation_supported(info.act_info);
_dwc_optimized_func = std::make_unique<CpuDepthwiseConv2dAssemblyDispatch>();
if(_is_nchw)
@@ -359,7 +351,7 @@ Status CpuDepthwiseConv2d::CpuDepthwiseConv2dGeneric::validate(const ITensorInfo
}
// Validate Activation Layer
- if(info.act_info.enabled())
+ if(info.act_info.enabled() && !CpuDepthwiseConv2dAssemblyDispatch::is_activation_supported(info.act_info))
{
ARM_COMPUTE_RETURN_ON_ERROR(CpuActivation::validate(dst, nullptr, info.act_info));
}
diff --git a/src/runtime/cpu/operators/CpuDepthwiseConv2d.h b/src/runtime/cpu/operators/CpuDepthwiseConv2d.h
index 049397fe60..ae9f894aab 100644
--- a/src/runtime/cpu/operators/CpuDepthwiseConv2d.h
+++ b/src/runtime/cpu/operators/CpuDepthwiseConv2d.h
@@ -92,9 +92,8 @@ private:
*
* -# @ref NEFillBorderKernel (if pad_x or pad_y > 0) and no assembly kernel implementation is present
* -# @ref CpuDepthwiseConv2d3x3Kernel if 3x3 and no assembly kernel implementation is present
- * -# @ref NEDepthwiseConvolutionAssemblyDispatch if assembly kernel implementation is present
- * -# @ref NEDirectConvolutionLayerOutputStageKernel if re-quantization of dst is required
- * -# @ref NEActivationLayer if fused activation is required
+ * -# @ref CpuDepthwiseConv2dAssemblyDispatch if assembly kernel implementation is present
+ * -# @ref CpuActivation if fused activation is required
*
*/
class CpuDepthwiseConv2dOptimizedInternal : public ICpuOperator
diff --git a/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp b/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
index a36ee1d45b..660ac0163c 100644
--- a/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
+++ b/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
@@ -24,315 +24,22 @@
#include "src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.h"
-#include "arm_compute/core/ITensor.h"
-#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/utils/misc/InfoHelpers.h"
-#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
+#include "arm_compute/core/ITensorInfo.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "src/core/CPP/Validate.h"
-#include "src/core/NEON/kernels/assembly/NEDepthwiseConvolutionAssemblyKernelWrapper.h"
-#include "src/core/NEON/kernels/convolution/depthwise/depthwise_dilated.hpp"
-#include "src/core/NEON/kernels/convolution/depthwise/depthwise_quantized_dilated.hpp"
+#include "src/core/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.h"
#include "src/core/helpers/AutoConfiguration.h"
-
-#include "arm_compute/runtime/NEON/NEScheduler.h"
-
-#include <set>
+#include "src/core/utils/AssemblyUtils.h"
namespace arm_compute
{
namespace cpu
{
-namespace
-{
-std::unique_ptr<depthwise::IDepthwiseConvolution> get_qasymm8_convolver(int kernel_size, int stride_x,
- int n_batches, int in_rows, int in_cols, int n_channels,
- int dilation_factor, neon_convolution_kernels::ActivationFunction activation,
- const qasymm8::QAsymm8Params &wqinfo, const qasymm8::QAsymm8Params &iqinfo, const qasymm8::QAsymm8Params &oqinfo,
- const qasymm8::QAsymm8RescaleParams &rescale_params,
- int padding_top, int padding_left, int padding_bottom, int padding_right)
-{
- switch(kernel_size)
- {
- case 3:
- {
- switch(stride_x)
- {
- case 1:
- return std::make_unique<depthwise::QAsymm8DilatedDepthwiseConvolution<2, 2, 3, 3, 1, 1>>(
- n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
- case 2:
- return std::make_unique<depthwise::QAsymm8DilatedDepthwiseConvolution<2, 2, 3, 3, 2, 2>>(
- n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
- default:
- return nullptr;
- }
- }
- case 5:
- {
- switch(stride_x)
- {
- case 1:
- return std::make_unique<depthwise::QAsymm8DilatedDepthwiseConvolution<2, 2, 5, 5, 1, 1>>(
- n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
- case 2:
- return std::make_unique<depthwise::QAsymm8DilatedDepthwiseConvolution<2, 2, 5, 5, 2, 2>>(
- n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
- default:
- return nullptr;
- }
- }
- default:
- return nullptr;
- }
-}
-
-std::unique_ptr<depthwise::IDepthwiseConvolution> get_qsymm8_perchannel_convolver(int kernel_size, int stride_x,
- int n_batches, int in_rows, int in_cols, int n_channels,
- neon_convolution_kernels::ActivationFunction activation,
- const qsymm8::QSymm8PerChannelParams &wqinfo, const qasymm8::QAsymm8Params &iqinfo, const qasymm8::QAsymm8Params &oqinfo,
- const qsymm8::QSymm8PerChannelRescaleParams &rescale_params,
- int padding_top, int padding_left, int padding_bottom, int padding_right)
-{
- switch(kernel_size)
- {
- case 3:
- {
- switch(stride_x)
- {
- case 1:
- return std::make_unique<depthwise::QSymm8HybridPerChannelDepthwiseConvolution<2, 2, 3, 3, 1, 1>>(
- n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
- case 2:
- return std::make_unique<depthwise::QSymm8HybridPerChannelDepthwiseConvolution<2, 2, 3, 3, 2, 2>>(
- n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
- default:
- return nullptr;
- }
- }
- case 5:
- {
- switch(stride_x)
- {
- case 1:
- return std::make_unique<depthwise::QSymm8HybridPerChannelDepthwiseConvolution<2, 2, 5, 5, 1, 1>>(
- n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
- case 2:
- return std::make_unique<depthwise::QSymm8HybridPerChannelDepthwiseConvolution<2, 2, 5, 5, 2, 2>>(
- n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
- default:
- return nullptr;
- }
- }
- default:
- return nullptr;
- }
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-std::unique_ptr<depthwise::IDepthwiseConvolution> get_fp16_convolver(int kernel_size, int stride_x,
- int n_batches, int in_rows, int in_cols, int n_channels,
- int dilation_factor, neon_convolution_kernels::ActivationFunction activation,
- int padding_top, int padding_left, int padding_bottom, int padding_right)
-{
- switch(kernel_size)
- {
- case 3:
- {
- switch(stride_x)
- {
- case 1:
- return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 3, 3, 1, 1, float16_t, float16_t, float16_t>>(
- n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
- case 2:
- return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 3, 3, 2, 2, float16_t, float16_t, float16_t>>(
- n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
- default:
- return nullptr;
- }
- }
- case 5:
- {
- switch(stride_x)
- {
- case 1:
- return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 5, 5, 1, 1, float16_t, float16_t, float16_t>>(
- n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
- case 2:
- return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 5, 5, 2, 2, float16_t, float16_t, float16_t>>(
- n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
- default:
- return nullptr;
- }
- }
- default:
- return nullptr;
- }
-}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
-std::unique_ptr<depthwise::IDepthwiseConvolution> get_fp32_convolver(int kernel_size, int stride_x,
- int n_batches, int in_rows, int in_cols, int n_channels,
- int dilation_factor, neon_convolution_kernels::ActivationFunction activation,
- int padding_top, int padding_left, int padding_bottom, int padding_right)
-{
- switch(kernel_size)
- {
- case 3:
- {
- switch(stride_x)
- {
- case 1:
- return std::make_unique<depthwise::DilatedDepthwiseConvolution<4, 4, 3, 3, 1, 1, float, float, float>>(
- n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
- case 2:
- return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 3, 3, 2, 2, float, float, float>>(
- n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
- default:
- return nullptr;
- }
- }
- case 5:
- {
- switch(stride_x)
- {
- case 1:
- return std::make_unique<depthwise::DilatedDepthwiseConvolution<4, 4, 5, 5, 1, 1, float, float, float>>(
- n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
- case 2:
- return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 5, 5, 2, 2, float, float, float>>(
- n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
- default:
- return nullptr;
- }
- }
- default:
- return nullptr;
- }
-}
-
-std::unique_ptr<depthwise::IDepthwiseConvolution> create_convolver(const ITensorInfo *src,
- const ITensorInfo *weights,
- ITensorInfo *output,
- const ConvolutionInfo &info)
-{
- const DataType data_type = src->data_type();
- const TensorShape shape = src->tensor_shape();
-
- const int n_batches = shape[3];
- const int in_rows = shape.z();
- const int in_cols = shape.y();
- const int n_channels = shape.x();
- const int dilation_factor = info.dilation.x();
- const int padding_top = info.pad_stride_info.pad_top();
- const int padding_left = info.pad_stride_info.pad_left();
- const int padding_bottom = info.pad_stride_info.pad_bottom();
- const int padding_right = info.pad_stride_info.pad_right();
-
- const bool is_uniform_quantized = (data_type == DataType::QASYMM8) && (weights->data_type() == DataType::QASYMM8);
- const bool is_perchannel_quantized = (data_type == DataType::QASYMM8) && (weights->data_type() == DataType::QSYMM8_PER_CHANNEL);
-
- const unsigned int stride_x = info.pad_stride_info.stride().first;
- const unsigned int kernel_size = weights->tensor_shape().y();
-
- // Map activation function
- neon_convolution_kernels::ActivationFunction activation = neon_convolution_kernels::ActivationFunction::None;
- if(arm_compute::utils::info_helpers::is_relu(info.act_info))
- {
- activation = neon_convolution_kernels::ActivationFunction::ReLU;
- }
- else if(arm_compute::utils::info_helpers::is_relu6(info.act_info))
- {
- activation = neon_convolution_kernels::ActivationFunction::ReLU6;
- }
-
- // Create quantized convolver
- if(is_uniform_quantized)
- {
- const UniformQuantizationInfo input_qinfo = src->quantization_info().uniform();
- const UniformQuantizationInfo weights_qinfo = weights->quantization_info().uniform();
- const UniformQuantizationInfo output_qinfo = output->quantization_info().uniform();
-
- // Check that quantization info are in the range [0, 255]
- ARM_COMPUTE_ERROR_ON(input_qinfo.offset < 0 || input_qinfo.offset > 255);
- ARM_COMPUTE_ERROR_ON(weights_qinfo.offset < 0 || weights_qinfo.offset > 255);
- ARM_COMPUTE_ERROR_ON(output_qinfo.offset < 0 || output_qinfo.offset > 255);
- const qasymm8::QAsymm8Params iqinfo{ static_cast<uint8_t>(input_qinfo.offset), input_qinfo.scale };
- const qasymm8::QAsymm8Params wqinfo{ static_cast<uint8_t>(weights_qinfo.offset), weights_qinfo.scale };
- const qasymm8::QAsymm8Params oqinfo{ static_cast<uint8_t>(output_qinfo.offset), output_qinfo.scale };
-
- // Calculate rescale parameters
- const float fmultipler = iqinfo.scale * wqinfo.scale / oqinfo.scale;
- int32_t qmultiplier = 0;
- int32_t qshift = 0;
- quantization::calculate_quantized_multiplier_less_than_one(fmultipler, &qmultiplier, &qshift);
- qasymm8::QAsymm8RescaleParams rescale_params(qshift, qmultiplier, fmultipler);
-
- return get_qasymm8_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation,
- wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
- }
- else if(is_perchannel_quantized)
- {
- const UniformQuantizationInfo input_qinfo = src->quantization_info().uniform();
- const QuantizationInfo weights_qinfo = weights->quantization_info();
- const UniformQuantizationInfo output_qinfo = output->quantization_info().uniform();
-
- // Check that quantization info are in the range [0, 255]
- ARM_COMPUTE_ERROR_ON(input_qinfo.offset < 0 || input_qinfo.offset > 255);
- ARM_COMPUTE_ERROR_ON(output_qinfo.offset < 0 || output_qinfo.offset > 255);
- const qasymm8::QAsymm8Params iqinfo{ static_cast<uint8_t>(input_qinfo.offset), input_qinfo.scale };
- const qsymm8::QSymm8PerChannelParams wqinfo{ weights_qinfo.scale() };
- const qasymm8::QAsymm8Params oqinfo{ static_cast<uint8_t>(output_qinfo.offset), output_qinfo.scale };
-
- // Calculate rescale parameters
- std::vector<float> fmultipliers;
- std::vector<int32_t> qmultipliers;
- std::vector<int32_t> qshifts;
-
- for(auto const s : wqinfo.scales)
- {
- const float fmultipler = iqinfo.scale * s / oqinfo.scale;
- int32_t qmultiplier = 0;
- int32_t qshift = 0;
- quantization::calculate_quantized_multiplier_less_than_one(fmultipler, &qmultiplier, &qshift);
- fmultipliers.push_back(fmultipler);
- qmultipliers.push_back(qmultiplier);
- qshifts.push_back(qshift);
- }
-
- qsymm8::QSymm8PerChannelRescaleParams rescale_params(qshifts, qmultipliers, fmultipliers);
-
- return get_qsymm8_perchannel_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, activation,
- wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
- }
- else
- {
- // Create float convolver
- switch(data_type)
- {
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- {
- return get_fp16_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
- }
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F32:
- {
- return get_fp32_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
- }
- default:
- return nullptr;
- }
- }
-}
-} // namespace
-
struct CpuDepthwiseConv2dAssemblyDispatch::LocalImpl
{
- std::unique_ptr<depthwise::IDepthwiseConvolution> dwc_assembly_kernel{ nullptr };
- NEDepthwiseConvolutionAssemblyKernelWrapper dwc_acl_kernel{};
- bool is_prepared{ false };
- experimental::MemoryRequirements mem_req{};
+ std::unique_ptr<kernels::CpuDepthwiseConv2dAssemblyWrapperKernel> asm_kernel{ nullptr };
+ bool is_prepared{ false };
+ experimental::MemoryRequirements mem_req{};
};
#ifndef DOXYGEN_SKIP_THIS
@@ -350,206 +57,71 @@ void CpuDepthwiseConv2dAssemblyDispatch::configure(const ITensorInfo *src,
ITensorInfo *dst,
const ConvolutionInfo &info)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst);
- ARM_COMPUTE_UNUSED(bias);
- ARM_COMPUTE_ERROR_THROW_ON(CpuDepthwiseConv2dAssemblyDispatch::validate(src,
- weights,
- bias != nullptr ? bias : nullptr,
- dst,
- info));
-
- // Output auto inizialitation if not yet initialized
- const TensorShape dst_shape = misc::shape_calculator::compute_depthwise_convolution_shape(*src, *weights, info);
- auto_init_if_empty(*dst, src->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(dst_shape).set_quantization_info(dst->quantization_info()));
-
- _pImpl->is_prepared = false;
+ const CPUInfo &ci = NEScheduler::get().cpu_info();
+ const unsigned int num_threads = NEScheduler::get().num_threads();
+ _pImpl->is_prepared = false;
- // Create convolver
- _pImpl->dwc_assembly_kernel = create_convolver(src, weights, dst, info);
- ARM_COMPUTE_ERROR_ON(_pImpl->dwc_assembly_kernel == nullptr);
-
- // Create assembly kernel wrapper
- _pImpl->dwc_acl_kernel.configure(_pImpl->dwc_assembly_kernel.get());
-
- constexpr size_t alignment = 128;
-
- // Create workspace
- const unsigned int num_threads = NEScheduler::get().num_threads();
- const size_t workspace_size = _pImpl->dwc_assembly_kernel->get_working_space_size(num_threads);
- ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "Workspace size cannot be 0 !");
- _pImpl->mem_req.push_back({ TensorType::ACL_INT_0, workspace_size, alignment });
+ // If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
+ if(!CpuDepthwiseConv2dAssemblyDispatch::validate(src, weights, bias, dst, info))
+ {
+ return;
+ }
- // Create packing tensor
- const size_t pack_tensor_size = _pImpl->dwc_assembly_kernel->get_packed_params_size();
- ARM_COMPUTE_ERROR_ON_MSG(pack_tensor_size == 0, "Pack tensor size cannot be 0 !");
+ auto dwc_wrapper = std::make_unique<kernels::CpuDepthwiseConv2dAssemblyWrapperKernel>();
+ ARM_COMPUTE_ERROR_ON(dwc_wrapper == nullptr);
+ dwc_wrapper->configure(src, weights, bias, dst, info, ci);
- _pImpl->mem_req.push_back({ TensorType::ACL_INT_1, pack_tensor_size, alignment });
+ // Compute memory requirements for assembly kernels
+ constexpr size_t alignment = 4096;
+ _pImpl->mem_req.push_back({ TensorType::ACL_INT_0, dwc_wrapper->get_working_size(num_threads, src->dimension(0)), alignment });
+ _pImpl->mem_req.push_back({ TensorType::ACL_INT_1, dwc_wrapper->get_storage_size(), alignment });
+ _pImpl->asm_kernel = std::move(dwc_wrapper);
}
-experimental::MemoryRequirements CpuDepthwiseConv2dAssemblyDispatch::workspace() const
+Status CpuDepthwiseConv2dAssemblyDispatch::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *bias, const ITensorInfo *dst, const ConvolutionInfo &info)
{
- return _pImpl->mem_req;
+ return kernels::CpuDepthwiseConv2dAssemblyWrapperKernel::validate(src, weights, bias, dst, info);
}
-Status CpuDepthwiseConv2dAssemblyDispatch::validate(const ITensorInfo *src,
- const ITensorInfo *weights,
- const ITensorInfo *bias,
- const ITensorInfo *dst,
- const ConvolutionInfo &info)
+experimental::MemoryRequirements CpuDepthwiseConv2dAssemblyDispatch::workspace() const
{
- ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
- if(weights->data_type() != DataType::QSYMM8_PER_CHANNEL)
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights);
- }
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights);
-
- // Validate convolver
- ARM_COMPUTE_RETURN_ERROR_ON(!is_optimized_supported(src, weights, info));
-
- // Validate activation
- const bool is_relu = arm_compute::utils::info_helpers::is_relu(info.act_info);
- const bool is_relu6 = arm_compute::utils::info_helpers::is_relu6(info.act_info);
- ARM_COMPUTE_RETURN_ERROR_ON(info.act_info.enabled() && !(is_relu || is_relu6));
-
- // Check bias
- if(bias != nullptr)
- {
- unsigned int channel_idx = get_data_layout_dimension_index(src->data_layout(), DataLayoutDimension::CHANNEL);
- ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1);
- ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != weights->dimension(channel_idx));
- }
-
- // Check output
- if(dst->total_size() != 0)
- {
- const TensorShape dst_shape = misc::shape_calculator::compute_depthwise_convolution_shape(*src, *weights, info);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), dst_shape);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
- }
-
- // The uniform quantization case will only have 1 scale value in the weights quantization info
- const UniformQuantizationInfo src_qinfo = src->quantization_info().uniform();
- const QuantizationInfo weights_qinfo = weights->quantization_info();
- const UniformQuantizationInfo dst_qinfo = dst->quantization_info().uniform();
- for(auto const s : weights_qinfo.scale())
- {
- const float fmultipler = src_qinfo.scale * s / dst_qinfo.scale;
- ARM_COMPUTE_RETURN_ERROR_ON(fmultipler > 1.f);
- }
-
- return Status{};
+ return _pImpl->mem_req;
}
-bool CpuDepthwiseConv2dAssemblyDispatch::is_optimized_supported(const ITensorInfo *src,
- const ITensorInfo *weights,
- const ConvolutionInfo &info)
+bool CpuDepthwiseConv2dAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights);
-
- // Reshape input shape if in NHWC format
- const DataLayout data_layout = src->data_layout();
- TensorShape in_shape{ src->tensor_shape() };
- if(data_layout == DataLayout::NHWC)
- {
- in_shape.set(Window::DimX, src->tensor_shape().y());
- in_shape.set(Window::DimY, src->tensor_shape().z());
- in_shape.set(Window::DimZ, src->tensor_shape().x());
- }
-
- // Check data type
- const DataType input_type = src->data_type();
- const bool is_input_type_valid = is_data_type_float(input_type) || input_type == DataType::QASYMM8;
- const DataType weights_type = weights->data_type();
- const bool is_weights_type_valid = is_data_type_float(weights_type) || weights_type == DataType::QASYMM8 || weights_type == DataType::QASYMM8_SIGNED
- || weights_type == DataType::QSYMM8_PER_CHANNEL;
-
- // Check weighs size
- std::set<unsigned int> supported_kernel_sizes = { 3, 5 };
- const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
- const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
- const unsigned int kernel_w = weights->dimension(width_idx);
- const unsigned int kernel_h = weights->dimension(height_idx);
- bool weights_supported = (kernel_w == kernel_h) && (supported_kernel_sizes.count(kernel_w) != 0);
-
- // Check for supported strides
- const auto &strides = info.pad_stride_info.stride();
- bool supported_strides = (strides.first == strides.second) && ((strides.first == 1) || (strides.first == 2));
-
- // Check for supported padding
- const auto pad_top = info.pad_stride_info.pad_top();
- const auto pad_right = info.pad_stride_info.pad_right();
- const auto pad_bottom = info.pad_stride_info.pad_bottom();
- const auto pad_left = info.pad_stride_info.pad_left();
- PadStrideInfo same_pad = calculate_same_pad(in_shape, TensorShape(kernel_w, kernel_h), info.pad_stride_info, DataLayout::NCHW, info.dilation);
- bool is_same_padding = (pad_top == same_pad.pad_top()) && (pad_right == same_pad.pad_right()) && (pad_bottom == same_pad.pad_bottom()) && (pad_left == same_pad.pad_left());
- bool is_valid_padding = (pad_top == 0) && (pad_right == 0) && (pad_bottom == 0) && (pad_left == 0);
- bool supported_padding = is_same_padding || is_valid_padding;
- // TODO(COMPMID-2464): Enable once dilated conv with stride 2 is supported
- bool is_dilation_supported = ((info.dilation == Size2D(1U, 1U)) || ((info.dilation.x() == info.dilation.y()) && strides.first == 1));
-
- if(weights_type == DataType::QSYMM8_PER_CHANNEL)
- {
- is_dilation_supported = is_dilation_supported && (info.dilation == Size2D(1U, 1U));
- }
-
- return is_input_type_valid && is_weights_type_valid && weights_supported && supported_strides && supported_padding && (info.depth_multiplier == 1) && is_dilation_supported;
+ arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(activation);
+ return act.type != arm_gemm::Activation::Type::None;
}
void CpuDepthwiseConv2dAssemblyDispatch::run(ITensorPack &tensors)
{
- // Prepare assembly kernel
- prepare(tensors);
+ ARM_COMPUTE_ERROR_ON_MSG(tensors.empty(), "No inputs provided");
- auto src = tensors.get_tensor(TensorType::ACL_SRC_0);
- auto workspace = tensors.get_tensor(TensorType::ACL_INT_0);
- auto dst = tensors.get_tensor(TensorType::ACL_DST);
-
- // Setup inputs/outputs
- ARM_COMPUTE_ERROR_ON(workspace == nullptr && workspace->buffer() == nullptr);
- _pImpl->dwc_assembly_kernel->set_working_space(static_cast<void *>(workspace->buffer()));
-
- ARM_COMPUTE_ERROR_ON(workspace->buffer() == nullptr);
- const int input_element_size = src->info()->element_size();
- const int input_batch_stride = src->info()->strides_in_bytes()[3] / input_element_size;
- const int input_row_stride = src->info()->strides_in_bytes().z() / input_element_size;
- const int input_col_stride = src->info()->strides_in_bytes().y() / input_element_size;
- const void *input_ptr = src->buffer() + src->info()->offset_first_element_in_bytes();
- _pImpl->dwc_assembly_kernel->set_input(input_ptr, input_batch_stride, input_row_stride, input_col_stride);
-
- ARM_COMPUTE_ERROR_ON(dst->buffer() == nullptr);
- const int output_element_size = dst->info()->element_size();
- const int output_batch_stride = dst->info()->strides_in_bytes()[3] / output_element_size;
- const int output_row_stride = dst->info()->strides_in_bytes().z() / output_element_size;
- const int output_col_stride = dst->info()->strides_in_bytes().y() / output_element_size;
- void *output_ptr = dst->buffer() + dst->info()->offset_first_element_in_bytes();
- _pImpl->dwc_assembly_kernel->set_output(output_ptr, output_batch_stride, output_row_stride, output_col_stride);
+ prepare(tensors);
- // Schedule assembly kernel
- NEScheduler::get().schedule(&_pImpl->dwc_acl_kernel, Window::DimX);
+ NEScheduler::get().schedule_op(_pImpl->asm_kernel.get(), Window::DimY, _pImpl->asm_kernel->window(), tensors);
}
void CpuDepthwiseConv2dAssemblyDispatch::prepare(ITensorPack &tensors)
{
if(!_pImpl->is_prepared)
{
- auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
- auto bias = tensors.get_const_tensor(TensorType::ACL_SRC_2);
- auto packed_weights = tensors.get_tensor(TensorType::ACL_INT_1);
+ // Pack weights and bias
+ const ITensor *weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+ const ITensor *bias = tensors.get_const_tensor(TensorType::ACL_SRC_2);
+ ITensor *storage = tensors.get_tensor(TensorType::ACL_INT_1);
- ARM_COMPUTE_ERROR_ON(packed_weights->buffer() == nullptr);
+ const auto weights_ptr = weights->buffer() + weights->info()->offset_first_element_in_bytes();
+ const auto bias_ptr = (bias) ? bias->buffer() + bias->info()->offset_first_element_in_bytes() : nullptr;
+ auto parameters_ptr = storage->buffer() + storage->info()->offset_first_element_in_bytes();
- // Pack weights and bias
- const int weights_element_size = weights->info()->element_size();
- const int weights_row_stride = weights->info()->strides_in_bytes().z() / weights_element_size;
- const int weights_col_stride = weights->info()->strides_in_bytes().y() / weights_element_size;
- _pImpl->dwc_assembly_kernel->pack_params(packed_weights->buffer(),
- weights->buffer() + weights->info()->offset_first_element_in_bytes(),
- weights_row_stride,
- weights_col_stride,
- (bias != nullptr) ? bias->buffer() : nullptr);
- _pImpl->dwc_assembly_kernel->set_packed_params_buffer(packed_weights->buffer());
+ const auto weights_shape = weights->info()->tensor_shape();
+ const auto weights_padding = weights->info()->padding();
+
+ const size_t ld_weights_col = weights_shape[0] + weights_padding.left + weights_padding.right;
+ const size_t ld_weights_row = ld_weights_col * (weights_shape[1] + weights_padding.top + weights_padding.bottom);
+ _pImpl->asm_kernel->pack_parameters(parameters_ptr, bias_ptr, weights_ptr, ld_weights_col, ld_weights_row);
weights->mark_as_unused();
if(bias != nullptr)
diff --git a/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.h b/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.h
index 195942b7fd..70845163f4 100644
--- a/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.h
+++ b/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_DEPTHWISECONV2DASSEMBLYDISPATCH_H
-#define ARM_COMPUTE_CPU_DEPTHWISECONV2DASSEMBLYDISPATCH_H
+#ifndef ARM_COMPUTE_CPU_DEPTHWISE_CONV2D_ASSEMBLY_DISPATCH_H
+#define ARM_COMPUTE_CPU_DEPTHWISE_CONV2D_ASSEMBLY_DISPATCH_H
#include "src/core/common/Macros.h"
#include "src/runtime/cpu/ICpuOperator.h"
@@ -40,15 +40,15 @@ public:
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuDepthwiseConv2dAssemblyDispatch);
/** Default destructor */
~CpuDepthwiseConv2dAssemblyDispatch();
-
/** Initialize the function's source, destination, kernels and border_size.
*
* @note Supports only NHWC format
*
- * @param[in] src Source tensor info. Data type supported: QASYMM8/F16/F32. (Written to only for border filling).
- * @param[in] weights Weights tensor info. These are 3D tensors with shape [W, H, IFM]. Data type supported: Same as @p src.
+ * @param[in] src Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
+ * @param[in] weights Weights tensor info. These are 3D tensors with shape [W, H, IFM].
+ * Data type supported: same as @p src or QASYMM8/QASYMM8_SIGNED/QSYMM8_PER_CHANNEL when @p src is QASYMM8/QASYMM8_SIGNED.
* @param[in] bias (Optional) Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
- * Data type supported: Same as @p src.
+ * Data type supported: same as @p src or S32 if @p src is quantized.
* @param[out] dst Destination tensor info. Data type supported: same as @p src.
* @param[in] info Depthwise convolution meta-data.
*/
@@ -60,18 +60,13 @@ public:
* @return a status
*/
static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *bias, const ITensorInfo *dst, const ConvolutionInfo &info);
- /** Check if the optimized kernel can be used for the given kernel sizes and strides
- *
- * @warning Even if this return true the inputs and outputs might need to get permuted as the only layout supported is NHWC
+ /** Checks if activation is supported by the assembly kernels
*
- * @param[in] src Input tensor info.
- * @param[in] weights Weights tensor info.
- * @param[in] info Depthwise convolution meta-data.
+ * @param[in] activation Activation to check
*
- * @return True if the assembly kernel could be used else false. Note that transformations of input/output could be needed.
+ * @return True if activation is supported else false
*/
- static bool is_optimized_supported(const ITensorInfo *src, const ITensorInfo *weights, const ConvolutionInfo &info);
-
+ static bool is_activation_supported(const ActivationLayerInfo &activation);
// Inherited methods overridden:
void run(ITensorPack &tensors) override;
void prepare(ITensorPack &tensors) override;
@@ -83,4 +78,4 @@ private:
};
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_DEPTHWISECONV2DASSEMBLYDISPATCH_H */
+#endif /* ARM_COMPUTE_CPU_DEPTHWISE_CONV2D_ASSEMBLY_DISPATCH_H */
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index ea3742fee5..1101e05a0d 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -27,6 +27,7 @@
#include "src/core/CPP/Validate.h"
#include "src/core/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h"
#include "src/core/cpu/kernels/assembly/arm_gemm.hpp"
+#include "src/core/utils/AssemblyUtils.h"
#include <arm_neon.h>
#include <cstdlib>
@@ -89,38 +90,6 @@ Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITen
return p;
}
-arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act)
-{
- arm_gemm::Activation gemm_act;
-
- // Early exit in case lower bound is other than 0, as it's not yet supported
- if(act.b() != 0.f)
- {
- return gemm_act;
- }
-
- switch(act.activation())
- {
- case ActivationLayerInfo::ActivationFunction::RELU:
- gemm_act.type = arm_gemm::Activation::Type::ReLU;
- break;
- case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
- gemm_act.type = arm_gemm::Activation::Type::BoundedReLU;
- gemm_act.param1 = act.a();
- gemm_act.param2 = 0.f;
- break;
- case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU:
- gemm_act.type = arm_gemm::Activation::Type::BoundedReLU;
- gemm_act.param1 = act.a();
- gemm_act.param2 = act.b();
- break;
- default:
- gemm_act.type = arm_gemm::Activation::Type::None;
- }
-
- return gemm_act;
-}
-
IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataType data_type)
{
// Schedule assembly kernel
@@ -788,14 +757,14 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation)
{
- arm_gemm::Activation act = map_to_arm_gemm_activation(activation);
+ arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(activation);
return act.type != arm_gemm::Activation::Type::None;
}
void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
- arm_gemm::Activation act = map_to_arm_gemm_activation(info.activation_info);
+ arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info);
//If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
if(!CpuGemmAssemblyDispatch::validate(a, b, c, d, info))