From 2a2db590fd179dcb8e1a575293cd2b887e2dc246 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 15 Aug 2018 12:14:46 +0100 Subject: COMPMID-1505: Add native grouping support at graph level Change-Id: Iedc91b0aee743b59af5140c8acb8124548da3163 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/144362 Tested-by: Jenkins Reviewed-by: Giorgio Arena Reviewed-by: Michele DiGiorgio --- arm_compute/graph/backends/ValidateHelpers.h | 49 +++++++++++++++------------- 1 file changed, 26 insertions(+), 23 deletions(-) (limited to 'arm_compute/graph/backends/ValidateHelpers.h') diff --git a/arm_compute/graph/backends/ValidateHelpers.h b/arm_compute/graph/backends/ValidateHelpers.h index ec84399ac6..3064db20c3 100644 --- a/arm_compute/graph/backends/ValidateHelpers.h +++ b/arm_compute/graph/backends/ValidateHelpers.h @@ -107,37 +107,30 @@ Status validate_convolution_layer(ConvolutionLayerNode &node) const PadStrideInfo conv_info = node.convolution_info(); const ConvolutionMethod conv_algorithm = node.convolution_method(); const bool fast_math = node.fast_math_hint() == FastMathHint::Enabled; + const unsigned int num_groups = node.num_groups(); // Validate function Status status{}; switch(conv_algorithm) { case ConvolutionMethod::Direct: + ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups != 1, "DirectConvolutionLayer does not support grouping!"); status = DirectConvolutionLayer::validate(input, weights, biases, output, conv_info); break; case ConvolutionMethod::GEMM: - status = GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info); + status = GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info, + WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), num_groups); break; case ConvolutionMethod::Winograd: + ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups != 1, "WinogradConvolutionLayer does not support grouping!"); status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info, ActivationLayerInfo(), fast_math); break; case ConvolutionMethod::Default: - status = ConvolutionLayer::validate(input, weights, biases, output, conv_info); + status = ConvolutionLayer::validate(input, weights, biases, output, conv_info, + WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), fast_math, num_groups); break; default: - break; - } - - // If validation fails try the Default approach - if(!bool(status)) - { - status = ConvolutionLayer::validate(input, weights, biases, output, conv_info /*, fast_math*/); - if(bool(status)) - { - ARM_COMPUTE_LOG_GRAPH_INFO("Switched ConvolutionLayer method of node with ID : " - << node.id() << " and Name: " << node.name() << std::endl); - node.set_convolution_method(ConvolutionMethod::Default); - } + ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported convolution method"); } return status; @@ -160,20 +153,30 @@ Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node) ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); // Extract IO and info - arm_compute::ITensorInfo *weights = detail::get_backing_tensor_info(node.input(1)); + arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0)); + arm_compute::ITensorInfo *weights = detail::get_backing_tensor_info(node.input(1)); + arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2)); + arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0)); + + const PadStrideInfo conv_info = node.convolution_info(); const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method(); - ARM_COMPUTE_ERROR_ON(weights == nullptr); - // TODO (geopin01) : Switch when validation is implemented // Validate function - if((dwc_algorithm == DepthwiseConvolutionMethod::Optimized3x3) && (weights->tensor_shape()[get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::WIDTH)] != 3)) + Status status{}; + switch(dwc_algorithm) { - ARM_COMPUTE_LOG_GRAPH_INFO("Switched DepthwiseConvolutionLayer method of node with ID : " - << node.id() << " and Name: " << node.name() << std::endl); - node.set_depthwise_convolution_method(DepthwiseConvolutionMethod::Default); + case DepthwiseConvolutionMethod::Default: + case DepthwiseConvolutionMethod::GEMV: + status = DepthwiseConvolutionLayer::validate(input, weights, biases, output, conv_info); + break; + case DepthwiseConvolutionMethod::Optimized3x3: + status = DepthwiseConvolutionLayer3x3::validate(input, weights, biases, output, conv_info); + break; + default: + ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported depthwise convolution method"); } - return Status{}; + return status; } /** Validates a permute layer node -- cgit v1.2.1