diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEConvolutionLayer.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEConvolutionLayer.cpp | 61 |
1 files changed, 35 insertions, 26 deletions
diff --git a/src/runtime/NEON/functions/NEConvolutionLayer.cpp b/src/runtime/NEON/functions/NEConvolutionLayer.cpp index 901b1e880e..cc5f160787 100644 --- a/src/runtime/NEON/functions/NEConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEConvolutionLayer.cpp @@ -27,27 +27,12 @@ #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" #include "arm_compute/runtime/NEON/NEScheduler.h" -#include "src/core/NEON/kernels/NECol2ImKernel.h" -#include "src/core/NEON/kernels/NEConvertQuantizedSignednessKernel.h" -#include "src/core/NEON/kernels/NECopyKernel.h" -#include "src/core/NEON/kernels/NEDirectConvolutionLayerKernel.h" -#include "src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h" -#include "src/core/NEON/kernels/NEFFTDigitReverseKernel.h" -#include "src/core/NEON/kernels/NEFFTRadixStageKernel.h" -#include "src/core/NEON/kernels/NEFFTScaleKernel.h" -#include "src/core/NEON/kernels/NEFillBorderKernel.h" -#include "src/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h" -#include "src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h" -#include "src/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.h" -#include "src/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.h" -#include "src/core/NEON/kernels/NEGEMMLowpReductionKernel.h" -#include "src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.h" -#include "src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h" -#include "src/core/NEON/kernels/NEGEMMTranspose1xWKernel.h" -#include "src/core/NEON/kernels/NEIm2ColKernel.h" -#include "src/core/NEON/kernels/NEPadLayerKernel.h" -#include "src/core/NEON/kernels/NEReductionOperationKernel.h" -#include "src/core/NEON/kernels/NEWeightsReshapeKernel.h" +#include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h" +#include "arm_compute/runtime/NEON/functions/NEFFTConvolutionLayer.h" +#include "arm_compute/runtime/NEON/functions/NEGEMMConv2d.h" +#include "arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h" +#include "arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h" + #include "support/MemorySupport.h" #include <cmath> @@ -71,6 +56,7 @@ void NEConvolutionLayer::configure(ITensor *input, const ITensor *weights, const ARM_COMPUTE_ERROR_THROW_ON(NEConvolutionLayer::validate(input->info(), weights->info(), ((biases != nullptr) ? biases->info() : nullptr), output->info(), conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups)); + const Conv2dInfo info(conv_info, dilation, act_info, enable_fast_math, num_groups); switch(NEConvolutionLayer::get_convolution_method(input->info(), weights->info(), output->info(), conv_info, weights_info, dilation, act_info, enable_fast_math)) { case ConvolutionMethod::WINOGRAD: @@ -87,6 +73,13 @@ void NEConvolutionLayer::configure(ITensor *input, const ITensor *weights, const _function = std::move(f); break; } + case ConvolutionMethod::GEMM_CONV2D: + { + auto f = arm_compute::support::cpp14::make_unique<NEGEMMConv2d>(_memory_manager); + f->configure(input, weights, biases, output, info); + _function = std::move(f); + break; + } case ConvolutionMethod::DIRECT: { auto f = arm_compute::support::cpp14::make_unique<NEDirectConvolutionLayer>(_memory_manager); @@ -112,22 +105,22 @@ Status NEConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo { ARM_COMPUTE_RETURN_ERROR_ON_MSG((num_groups != 1), "Grouping (num_groups != 1) is not supported on NEON"); + const Conv2dInfo info(conv_info, dilation, act_info, enable_fast_math, num_groups); switch(NEConvolutionLayer::get_convolution_method(input, weights, output, conv_info, weights_info, dilation, act_info, enable_fast_math)) { case ConvolutionMethod::WINOGRAD: - //Validate Winograd ARM_COMPUTE_RETURN_ON_ERROR(NEWinogradConvolutionLayer::validate(input, weights, biases, output, conv_info, act_info, enable_fast_math)); break; case ConvolutionMethod::GEMM: - //Validate Gemm-based Convolution ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMConvolutionLayer::validate(input, weights, biases, output, conv_info, weights_info, dilation, act_info)); break; + case ConvolutionMethod::GEMM_CONV2D: + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMConv2d::validate(input, weights, biases, output, info)); + break; case ConvolutionMethod::DIRECT: - //Validate Direct Convolution ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayer::validate(input, weights, biases, output, conv_info, act_info)); break; case ConvolutionMethod::FFT: - // Validate FFT-based convolution layer ARM_COMPUTE_RETURN_ON_ERROR(NEFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info)); break; default: @@ -149,6 +142,8 @@ ConvolutionMethod NEConvolutionLayer::get_convolution_method(const ITensorInfo * const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT); const size_t idx_c = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL); + const Conv2dInfo info(conv_info, dilation, act_info, enable_fast_math, 1); + /* Input spatial dims, kernel size, IFM/OFM, conv info*/ using ConvolutionConfiguration = std::tuple<Size2D, Size2D, Size2D, PadStrideInfo>; using ConfigurationMethod = std::pair<ConvolutionConfiguration, ConvolutionMethod>; @@ -235,7 +230,21 @@ ConvolutionMethod NEConvolutionLayer::get_convolution_method(const ITensorInfo * } } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - return bool(NEWinogradConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math)) ? ConvolutionMethod::WINOGRAD : ConvolutionMethod::GEMM; + // For 1x1 convolutions run the default GEMM + if(weights->dimension(idx_w) == 1 && weights->dimension(idx_h) == 1) + { + return ConvolutionMethod::GEMM; + } + + if(bool(NEWinogradConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math))) + { + return ConvolutionMethod::WINOGRAD; + } + if(bool(NEGEMMConv2d::validate(input, weights, nullptr, output, info))) + { + return ConvolutionMethod::GEMM_CONV2D; + } + return ConvolutionMethod::GEMM; } } |