From 8155c0253c00aa9e26651361460c66feb39829a6 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 16 Apr 2021 15:08:59 +0100 Subject: Rework OpenCL Depthwise Convolution - Remove dedicated kernels for NCHW. Now we only use NHWC with permute - Remove specialized kernels for 3x3 NHWC - Simplify CLDepthwiseConvolutionLayer.cpp to call just the native implementation for both floating-point and quantized data types - Develop two parametric opencl kernels for depthwise convolution layer NHWC (floating-point and quantized) - Add support to export the weights to cl_image - Extend test for depthwise convolution on opencl Resolves COMPMID-4417 Change-Id: Ibe533f79c2860f9cac8e921895d5a8f947753a5c Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5893 Reviewed-by: Giorgio Arena Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- .../CL/functions/CLDepthwiseConvolutionLayer.cpp | 474 +++++---------------- 1 file changed, 101 insertions(+), 373 deletions(-) (limited to 'src/runtime/CL/functions') diff --git a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp index a826f85c5c..84798fa672 100644 --- a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp @@ -23,16 +23,15 @@ */ #include "arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h" +#include "arm_compute/core/CL/CLHelpers.h" #include "arm_compute/core/CL/ICLTensor.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/PixelValue.h" +#include "arm_compute/core/Utils.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "arm_compute/runtime/CL/CLScheduler.h" -#include "src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.h" -#include "src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.h" #include "src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.h" -#include "src/core/CL/kernels/CLFillBorderKernel.h" namespace arm_compute { @@ -41,70 +40,97 @@ using namespace arm_compute::misc::shape_calculator; namespace { -Status validate_arguments_3x3(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, - unsigned int depth_multiplier, ActivationLayerInfo act_info, const Size2D &dilation) +bool export_weights_to_cl_image_heuristic(const ITensorInfo *weights, unsigned int depth_multiplier, GPUTarget gpu_target) { - // This function should be removed and incorporated inside CLDepthwiseConvolutionLayerInternal3x3 once CLDepthwiseConvolutionLayer3x3 is properly removed - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN); + if(!export_weights_to_cl_image(weights)) + { + return false; + } - const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type()); - const bool is_nhwc = input->data_layout() == DataLayout::NHWC; - const bool needs_permute = is_nhwc && (depth_multiplier > 1); + const size_t idx_w = get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::WIDTH); + const size_t idx_h = get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::HEIGHT); + const size_t kernel_w = weights->tensor_shape()[idx_w]; + const size_t kernel_h = weights->tensor_shape()[idx_h]; - ARM_COMPUTE_RETURN_ERROR_ON(is_quantized && is_nhwc && !needs_permute); + if((kernel_w == 1) && (kernel_h == 1)) + { + return false; + } - TensorInfo output_multipliers_shifts_info(TensorInfo(TensorShape(1U), 1, DataType::S32)); - if(is_quantized) + if(depth_multiplier > 1) { - if(is_data_type_quantized_per_channel(weights->data_type())) - { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QSYMM8_PER_CHANNEL); + return false; + } - const size_t idx_c = get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::CHANNEL); - output_multipliers_shifts_info.set_tensor_shape(TensorShape(weights->dimension(idx_c))); + if(gpu_target == GPUTarget::G71 || get_arch_from_target(gpu_target) == GPUTarget::MIDGARD) + { + return false; + } + + return true; +} + +void initialize_dwc_native_compute_info(DWCComputeKernelInfo &dwc_compute_info, const ITensorInfo *weights, const PadStrideInfo &conv_info, const Size2D &dilation, unsigned int depth_multiplier, + GPUTarget gpu_target) +{ + if(!is_data_type_float(weights->data_type())) + { + dwc_compute_info.export_weights_to_cl_image = false; + dwc_compute_info.n0 = (depth_multiplier == 1) ? 4 : 1; + if(conv_info.stride().first == 1 && dilation.x() == 1 && depth_multiplier == 1) + { + dwc_compute_info.m0 = 2; } else { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); + dwc_compute_info.m0 = 1; } - } - if(needs_permute) - { - TensorShape permuted_input_shape = input->tensor_shape(); - TensorShape permuted_weights_shape = weights->tensor_shape(); - const ConvolutionInfo info{ conv_info, depth_multiplier, ActivationLayerInfo(), dilation }; - TensorShape permuted_output_shape = shape_calculator::compute_depthwise_convolution_shape(*input, *weights, info); + return; + } - permute(permuted_input_shape, PermutationVector(1U, 2U, 0U)); - permute(permuted_weights_shape, PermutationVector(1U, 2U, 0U)); - permute(permuted_output_shape, PermutationVector(1U, 2U, 0U)); + // Floating point path - const TensorInfo permuted_input = input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_input_shape).set_data_layout(DataLayout::NCHW); - const TensorInfo permuted_weights = weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_weights_shape).set_data_layout(DataLayout::NCHW); - const TensorInfo permuted_output = output->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_output_shape).set_data_layout(DataLayout::NCHW); + // First check if we can export to cl_image. + dwc_compute_info.export_weights_to_cl_image = export_weights_to_cl_image_heuristic(weights, depth_multiplier, gpu_target); - ARM_COMPUTE_RETURN_ON_ERROR(CLDepthwiseConvolutionLayer3x3NCHWKernel::validate(&permuted_input, &permuted_weights, biases, &permuted_output, - conv_info, depth_multiplier, act_info, - dilation, &output_multipliers_shifts_info, &output_multipliers_shifts_info)); + // Set n0 + if(depth_multiplier == 1) + { + if(dwc_compute_info.export_weights_to_cl_image == false && weights->data_type() == DataType::F16) + { + dwc_compute_info.n0 = 8; + } + else + { + dwc_compute_info.n0 = 4; + } } - else if(is_nhwc) + else { - ARM_COMPUTE_RETURN_ON_ERROR(CLDepthwiseConvolutionLayer3x3NHWCKernel::validate(input, weights, biases, output, conv_info, depth_multiplier, act_info, - dilation)); + dwc_compute_info.n0 = 1; + } + + dwc_compute_info.n0 = adjust_vec_size(dwc_compute_info.n0, weights->dimension(0)); + + // Set m0 only if stride_x == 1 and dilation_x == 1 + if(conv_info.stride().first == 1 && dilation.x() == 1) + { + const size_t idx_w = get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::WIDTH); + const size_t kernel_w = weights->tensor_shape()[idx_w]; + + dwc_compute_info.m0 = (kernel_w >= 9) || (kernel_w == 1) ? 1 : 2; } else { - ARM_COMPUTE_RETURN_ON_ERROR(CLDepthwiseConvolutionLayer3x3NCHWKernel::validate(input, weights, biases, output, conv_info, depth_multiplier, act_info, - dilation, &output_multipliers_shifts_info, &output_multipliers_shifts_info)); + dwc_compute_info.m0 = 1; } - return Status{}; + return; } + } // namespace -CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::CLDepthwiseConvolutionLayerGeneric(std::shared_ptr memory_manager) +CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayer(std::shared_ptr memory_manager) : _memory_group(std::move(memory_manager)), _dwc_native_kernel(std::make_unique()), _permute_input_to_nhwc(), @@ -126,15 +152,15 @@ CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::CLDepthwiseConv CLDepthwiseConvolutionLayer::~CLDepthwiseConvolutionLayer() = default; -void CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, - unsigned int depth_multiplier, const ActivationLayerInfo &act_info, const Size2D &dilation) +void CLDepthwiseConvolutionLayer::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier, ActivationLayerInfo act_info, const Size2D &dilation) { configure(CLKernelLibrary::get().get_compile_context(), input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation); } -void CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::configure(const CLCompileContext &compile_context, ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, - ICLTensor *output, const PadStrideInfo &conv_info, - unsigned int depth_multiplier, const ActivationLayerInfo &act_info, const Size2D &dilation) +void CLDepthwiseConvolutionLayer::configure(const CLCompileContext &compile_context, ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, + ICLTensor *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier, ActivationLayerInfo act_info, const Size2D &dilation) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights); ARM_COMPUTE_ERROR_THROW_ON(CLDepthwiseConvolutionLayer::validate(input->info(), @@ -153,6 +179,8 @@ void CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::configure( _output = output; _needs_permute = input->info()->data_layout() == DataLayout::NCHW; + const GPUTarget gpu_target = CLScheduler::get().target(); + ICLTensor *input_to_use = input; const ICLTensor *weights_to_use = weights; ICLTensor *output_to_use = output; @@ -191,13 +219,13 @@ void CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::configure( output_shifts_to_use = &_output_shifts; } - DWCWeightsKernelInfo dwc_weights_info; - dwc_weights_info.n0 = (depth_multiplier == 1) ? 8 : 1; - DWCKernelInfo dwc_info; - dwc_info.activation_info = act_info; + DWCComputeKernelInfo dwc_native_compute_info; + initialize_dwc_native_compute_info(dwc_native_compute_info, weights_to_use->info(), conv_info, dilation, depth_multiplier, gpu_target); + + const ConvolutionInfo conv_kernel_info{ conv_info, depth_multiplier, act_info, dilation }; + _dwc_native_kernel->configure(compile_context, input_to_use, weights_to_use, biases, output_to_use, - dwc_weights_info, dwc_info, conv_info, depth_multiplier, dilation, - output_multipliers_to_use, output_shifts_to_use); + dwc_native_compute_info, conv_kernel_info, output_multipliers_to_use, output_shifts_to_use); if(_needs_permute) { @@ -216,9 +244,9 @@ void CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::configure( } } -Status CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, - const PadStrideInfo &conv_info, - unsigned int depth_multiplier, const ActivationLayerInfo &act_info, const Size2D &dilation) +Status CLDepthwiseConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, + const PadStrideInfo &conv_info, + unsigned int depth_multiplier, ActivationLayerInfo act_info, const Size2D &dilation) { const bool in_place = input == output || output == nullptr; if(in_place) @@ -232,10 +260,9 @@ Status CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::validate ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_w) + (weights->dimension(idx_w) - 1) * (dilation.x() - 1) > input->dimension(idx_w) + conv_info.pad_left() + conv_info.pad_right()); ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_h) + (weights->dimension(idx_h) - 1) * (dilation.y() - 1) > input->dimension(idx_h) + conv_info.pad_top() + conv_info.pad_bottom()); - DWCWeightsKernelInfo dwc_weights_info; - dwc_weights_info.n0 = (depth_multiplier == 1) ? 8 : 1; - DWCKernelInfo dwc_info; - dwc_info.activation_info = act_info; + const GPUTarget gpu_target = CLScheduler::get().target(); + + const ConvolutionInfo conv_kernel_info{ conv_info, depth_multiplier, act_info, dilation }; const bool needs_permute = input->data_layout() == DataLayout::NCHW; @@ -275,20 +302,25 @@ Status CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::validate ARM_COMPUTE_RETURN_ON_ERROR(CLPermute::validate(input, &permuted_input, PermutationVector(2U, 0U, 1U))); ARM_COMPUTE_RETURN_ON_ERROR(CLPermute::validate(weights, &permuted_weights, PermutationVector(2U, 0U, 1U))); - ARM_COMPUTE_RETURN_ON_ERROR(CLDepthwiseConvolutionLayerNativeKernel::validate(&permuted_input, &permuted_weights, biases, &permuted_output, dwc_weights_info, - dwc_info, conv_info, depth_multiplier, dilation, - &output_multipliers_shifts_info, &output_multipliers_shifts_info)); + + DWCComputeKernelInfo dwc_native_compute_info; + initialize_dwc_native_compute_info(dwc_native_compute_info, &permuted_weights, conv_info, dilation, depth_multiplier, gpu_target); + + ARM_COMPUTE_RETURN_ON_ERROR(CLDepthwiseConvolutionLayerNativeKernel::validate(&permuted_input, &permuted_weights, biases, &permuted_output, + dwc_native_compute_info, conv_kernel_info, &output_multipliers_shifts_info, &output_multipliers_shifts_info)); ARM_COMPUTE_RETURN_ON_ERROR(CLPermute::validate(&permuted_output, output, PermutationVector(1U, 2U, 0U))); } else { - ARM_COMPUTE_RETURN_ON_ERROR(CLDepthwiseConvolutionLayerNativeKernel::validate(input, weights, biases, output, dwc_weights_info, dwc_info, conv_info, depth_multiplier, - dilation, &output_multipliers_shifts_info, &output_multipliers_shifts_info)); + DWCComputeKernelInfo dwc_native_compute_info; + initialize_dwc_native_compute_info(dwc_native_compute_info, weights, conv_info, dilation, depth_multiplier, gpu_target); + ARM_COMPUTE_RETURN_ON_ERROR(CLDepthwiseConvolutionLayerNativeKernel::validate(input, weights, biases, output, dwc_native_compute_info, conv_kernel_info, &output_multipliers_shifts_info, + &output_multipliers_shifts_info)); } return Status{}; } -void CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::run() +void CLDepthwiseConvolutionLayer::run() { prepare(); @@ -305,7 +337,7 @@ void CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::run() } } -void CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::prepare() +void CLDepthwiseConvolutionLayer::prepare() { if(!_is_prepared) { @@ -333,308 +365,4 @@ void CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::prepare() _is_prepared = true; } } - -CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerInternal3x3::CLDepthwiseConvolutionLayerInternal3x3(std::shared_ptr memory_manager) - : _memory_group(std::move(memory_manager)), - _kernel_nchw(nullptr), - _kernel_nhwc(nullptr), - _border_handler(std::make_unique()), - _permute_input_to_nchw(), - _permute_weights_to_nchw(), - _permute_output_to_nhwc(), - _permuted_input(), - _permuted_weights(), - _permuted_output(), - _output_multipliers(), - _output_shifts(), - _original_weights(nullptr), - _input(nullptr), - _output(nullptr), - _needs_permute(false), - _is_prepared(false), - _is_quantized(false), - _is_nhwc(false) -{ -} - -void CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerInternal3x3::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, - const PadStrideInfo &conv_info, unsigned int depth_multiplier, ActivationLayerInfo act_info, const Size2D &dilation) -{ - configure(CLKernelLibrary::get().get_compile_context(), input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation); -} - -void CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerInternal3x3::configure(const CLCompileContext &compile_context, ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, - ICLTensor *output, - const PadStrideInfo &conv_info, unsigned int depth_multiplier, ActivationLayerInfo act_info, const Size2D &dilation) -{ - // Perform validation step - ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); - ARM_COMPUTE_ERROR_THROW_ON(CLDepthwiseConvolutionLayerInternal3x3::validate(input->info(), - weights->info(), - biases != nullptr ? biases->info() : nullptr, - output->info(), - conv_info, - depth_multiplier, - act_info, - dilation)); - - _is_nhwc = input->info()->data_layout() == DataLayout::NHWC; - _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type()); - _needs_permute = _is_nhwc && (depth_multiplier > 1); - - _is_prepared = false; - _original_weights = weights; - _input = input; - _output = output; - - ICLTensor *input_to_use = input; - const ICLTensor *weights_to_use = weights; - ICLTensor *output_to_use = output; - - const bool is_quantized_per_channel = is_data_type_quantized_per_channel(weights->info()->data_type()); - - if(_needs_permute) - { - _memory_group.manage(&_permuted_input); - _memory_group.manage(&_permuted_output); - - // Configure the function to transform the input tensor from NHWC -> NCHW - _permute_input_to_nchw.configure(compile_context, input, &_permuted_input, PermutationVector(1U, 2U, 0U)); - _permuted_input.info()->set_data_layout(DataLayout::NCHW); - - // Configure the function to transform the weights tensor from HWI -> IHW - _permute_weights_to_nchw.configure(compile_context, weights, &_permuted_weights, PermutationVector(1U, 2U, 0U)); - _permuted_weights.info()->set_data_layout(DataLayout::NCHW); - _permuted_output.info()->set_quantization_info(output->info()->quantization_info()); - - input_to_use = &_permuted_input; - weights_to_use = &_permuted_weights; - output_to_use = &_permuted_output; - - _kernel_nchw = std::make_unique(); - } - else if(_is_nhwc) - { - _kernel_nhwc = std::make_unique(); - } - else - { - _kernel_nchw = std::make_unique(); - } - - CLTensor *output_multipliers_to_use = nullptr; - CLTensor *output_shifts_to_use = nullptr; - if(_is_quantized) - { - const size_t idx_c = get_data_layout_dimension_index(weights->info()->data_layout(), DataLayoutDimension::CHANNEL); - const size_t num_filters = (is_quantized_per_channel) ? weights->info()->dimension(idx_c) : 1; - - _output_multipliers.allocator()->init(TensorInfo(TensorShape(num_filters), 1, DataType::S32)); - _output_shifts.allocator()->init(TensorInfo(TensorShape(num_filters), 1, DataType::S32)); - - output_multipliers_to_use = &_output_multipliers; - output_shifts_to_use = &_output_shifts; - } - - // Configure kernel - if(_is_nhwc && !_needs_permute) - { - _kernel_nhwc->configure(compile_context, input_to_use, weights_to_use, biases, output_to_use, conv_info, depth_multiplier, - act_info, dilation); - } - else - { - _kernel_nchw->configure(compile_context, input_to_use, weights_to_use, biases, output_to_use, conv_info, depth_multiplier, - act_info, dilation, output_multipliers_to_use, output_shifts_to_use); - } - - if(_is_quantized) - { - _output_multipliers.allocator()->allocate(); - _output_shifts.allocator()->allocate(); - } - - // Permute output if needed - if(_needs_permute) - { - // Configure the function to transform the convoluted output to ACL's native ordering format NCHW - _permuted_output.info()->set_data_layout(DataLayout::NCHW); - _permute_output_to_nhwc.configure(compile_context, &_permuted_output, output, PermutationVector(2U, 0U, 1U)); - - // Allocate tensors - _permuted_input.allocator()->allocate(); - _permuted_output.allocator()->allocate(); - } - // Configure border handler - PixelValue &&zero_value(0.f); - if(is_data_type_quantized_asymmetric(input->info()->data_type())) - { - zero_value = PixelValue(static_cast(input->info()->quantization_info().uniform().offset)); - } - if(!_is_nhwc || _needs_permute) - { - _border_handler->configure(compile_context, input_to_use, _kernel_nchw->border_size(), BorderMode::CONSTANT, zero_value); - } -} - -Status CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerInternal3x3::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, - const PadStrideInfo &conv_info, unsigned int depth_multiplier, ActivationLayerInfo act_info, const Size2D &dilation) -{ - return validate_arguments_3x3(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation); -} - -void CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerInternal3x3::run() -{ - prepare(); - - MemoryGroupResourceScope scope_mg(_memory_group); - - if(_needs_permute) - { - _permute_input_to_nchw.run(); - } - CLScheduler::get().enqueue(*_border_handler); - if(_is_nhwc && !_needs_permute) - { - CLScheduler::get().enqueue(*_kernel_nhwc); - } - else - { - CLScheduler::get().enqueue(*_kernel_nchw); - } - - if(_needs_permute) - { - _permute_output_to_nhwc.run(); - } -} - -void CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerInternal3x3::prepare() -{ - if(!_is_prepared) - { - if(_is_quantized) - { - _output_multipliers.map(); - _output_shifts.map(); - quantization::compute_quantized_multipliers_and_shifts(_input->info(), - _original_weights->info(), - _output->info(), - reinterpret_cast(_output_multipliers.ptr_to_element(Coordinates(0))), - reinterpret_cast(_output_shifts.ptr_to_element(Coordinates(0)))); - _output_multipliers.unmap(); - _output_shifts.unmap(); - } - - if(_needs_permute) - { - ARM_COMPUTE_ERROR_ON(!_original_weights->is_used()); - - _permuted_weights.allocator()->allocate(); - _permute_weights_to_nchw.run(); - _original_weights->mark_as_unused(); - } - - _is_prepared = true; - } -} - -CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayer(std::shared_ptr memory_manager) - : _memory_manager(std::move(memory_manager)), _depth_conv_func(DepthwiseConvolutionFunction::GENERIC), _func_3x3(), _func_generic() -{ -} - -void CLDepthwiseConvolutionLayer::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier, - ActivationLayerInfo act_info, const Size2D &dilation) -{ - configure(CLKernelLibrary::get().get_compile_context(), input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation); -} - -void CLDepthwiseConvolutionLayer::configure(const CLCompileContext &compile_context, ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, - const PadStrideInfo &conv_info, - unsigned int depth_multiplier, - ActivationLayerInfo act_info, const Size2D &dilation) -{ - if(output == nullptr) - { - // In-place - output = input; - } - _depth_conv_func = get_depthwiseconvolution_function(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), conv_info, depth_multiplier, act_info, - dilation); - switch(_depth_conv_func) - { - case DepthwiseConvolutionFunction::OPTIMIZED: - _func_3x3.set_memory_group(_memory_manager); - _func_3x3.configure(compile_context, input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation); - break; - case DepthwiseConvolutionFunction::GENERIC: - { - _func_generic.set_memory_group(_memory_manager); - _func_generic.configure(compile_context, input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation); - } - break; - default: - ARM_COMPUTE_ERROR("Unsupported DepthwiseConvolutionFunction"); - } -} - -Status CLDepthwiseConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, - unsigned int depth_multiplier, ActivationLayerInfo act_info, const Size2D &dilation) -{ - DepthwiseConvolutionFunction depth_conv_func = get_depthwiseconvolution_function(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation); - switch(depth_conv_func) - { - case DepthwiseConvolutionFunction::OPTIMIZED: - return CLDepthwiseConvolutionLayerInternal3x3::validate(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation); - case DepthwiseConvolutionFunction::GENERIC: - return CLDepthwiseConvolutionLayerGeneric::validate(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation); - default: - ARM_COMPUTE_ERROR("Unsupported DepthwiseConvolutionFunction"); - } -} - -DepthwiseConvolutionFunction CLDepthwiseConvolutionLayer::get_depthwiseconvolution_function(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, - const PadStrideInfo &conv_info, - unsigned int depth_multiplier, ActivationLayerInfo act_info, const Size2D &dilation) -{ - if(bool(CLDepthwiseConvolutionLayerInternal3x3::validate(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation))) - { - return DepthwiseConvolutionFunction::OPTIMIZED; - } - else - { - return DepthwiseConvolutionFunction::GENERIC; - } -} - -void CLDepthwiseConvolutionLayer::run() -{ - switch(_depth_conv_func) - { - case DepthwiseConvolutionFunction::OPTIMIZED: - _func_3x3.run(); - break; - case DepthwiseConvolutionFunction::GENERIC: - _func_generic.run(); - break; - default: - ARM_COMPUTE_ERROR("DepthwiseConvolutionFunction not properly configured"); - } -} - -void CLDepthwiseConvolutionLayer::prepare() -{ - switch(_depth_conv_func) - { - case DepthwiseConvolutionFunction::OPTIMIZED: - _func_3x3.prepare(); - break; - case DepthwiseConvolutionFunction::GENERIC: - _func_generic.prepare(); - break; - default: - ARM_COMPUTE_ERROR("DepthwiseConvolutionFunction not properly configured"); - } -} } // namespace arm_compute -- cgit v1.2.1