aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/backends/ValidateHelpers.h
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-08-15 12:14:46 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commit2a2db590fd179dcb8e1a575293cd2b887e2dc246 (patch)
tree5e10da7cb6777f3020b84a2389b279ceef2be5ee /arm_compute/graph/backends/ValidateHelpers.h
parentc1961b51df2e15a01a5950139e81bbd47fbfa627 (diff)
downloadComputeLibrary-2a2db590fd179dcb8e1a575293cd2b887e2dc246.tar.gz
COMPMID-1505: Add native grouping support at graph level
Change-Id: Iedc91b0aee743b59af5140c8acb8124548da3163 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/144362 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
Diffstat (limited to 'arm_compute/graph/backends/ValidateHelpers.h')
-rw-r--r--arm_compute/graph/backends/ValidateHelpers.h49
1 files changed, 26 insertions, 23 deletions
diff --git a/arm_compute/graph/backends/ValidateHelpers.h b/arm_compute/graph/backends/ValidateHelpers.h
index ec84399ac6..3064db20c3 100644
--- a/arm_compute/graph/backends/ValidateHelpers.h
+++ b/arm_compute/graph/backends/ValidateHelpers.h
@@ -107,37 +107,30 @@ Status validate_convolution_layer(ConvolutionLayerNode &node)
const PadStrideInfo conv_info = node.convolution_info();
const ConvolutionMethod conv_algorithm = node.convolution_method();
const bool fast_math = node.fast_math_hint() == FastMathHint::Enabled;
+ const unsigned int num_groups = node.num_groups();
// Validate function
Status status{};
switch(conv_algorithm)
{
case ConvolutionMethod::Direct:
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups != 1, "DirectConvolutionLayer does not support grouping!");
status = DirectConvolutionLayer::validate(input, weights, biases, output, conv_info);
break;
case ConvolutionMethod::GEMM:
- status = GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info);
+ status = GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info,
+ WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), num_groups);
break;
case ConvolutionMethod::Winograd:
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups != 1, "WinogradConvolutionLayer does not support grouping!");
status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info, ActivationLayerInfo(), fast_math);
break;
case ConvolutionMethod::Default:
- status = ConvolutionLayer::validate(input, weights, biases, output, conv_info);
+ status = ConvolutionLayer::validate(input, weights, biases, output, conv_info,
+ WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), fast_math, num_groups);
break;
default:
- break;
- }
-
- // If validation fails try the Default approach
- if(!bool(status))
- {
- status = ConvolutionLayer::validate(input, weights, biases, output, conv_info /*, fast_math*/);
- if(bool(status))
- {
- ARM_COMPUTE_LOG_GRAPH_INFO("Switched ConvolutionLayer method of node with ID : "
- << node.id() << " and Name: " << node.name() << std::endl);
- node.set_convolution_method(ConvolutionMethod::Default);
- }
+ ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported convolution method");
}
return status;
@@ -160,20 +153,30 @@ Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
// Extract IO and info
- arm_compute::ITensorInfo *weights = detail::get_backing_tensor_info(node.input(1));
+ arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
+ arm_compute::ITensorInfo *weights = detail::get_backing_tensor_info(node.input(1));
+ arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2));
+ arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
+
+ const PadStrideInfo conv_info = node.convolution_info();
const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method();
- ARM_COMPUTE_ERROR_ON(weights == nullptr);
- // TODO (geopin01) : Switch when validation is implemented
// Validate function
- if((dwc_algorithm == DepthwiseConvolutionMethod::Optimized3x3) && (weights->tensor_shape()[get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::WIDTH)] != 3))
+ Status status{};
+ switch(dwc_algorithm)
{
- ARM_COMPUTE_LOG_GRAPH_INFO("Switched DepthwiseConvolutionLayer method of node with ID : "
- << node.id() << " and Name: " << node.name() << std::endl);
- node.set_depthwise_convolution_method(DepthwiseConvolutionMethod::Default);
+ case DepthwiseConvolutionMethod::Default:
+ case DepthwiseConvolutionMethod::GEMV:
+ status = DepthwiseConvolutionLayer::validate(input, weights, biases, output, conv_info);
+ break;
+ case DepthwiseConvolutionMethod::Optimized3x3:
+ status = DepthwiseConvolutionLayer3x3::validate(input, weights, biases, output, conv_info);
+ break;
+ default:
+ ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported depthwise convolution method");
}
- return Status{};
+ return status;
}
/** Validates a permute layer node