From fb2280381e7a98ad698ea0c1b2cd635a48ad4acc Mon Sep 17 00:00:00 2001 From: Sheri Zhang Date: Tue, 2 Nov 2021 10:45:07 +0000 Subject: Add graph level convolution fusion with post operator Resolves: COMPMID-4701 Signed-off-by: Sheri Zhang Change-Id: I8a0d3c2ed4bf84489d94b8ae6641d6041aadaee5 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6557 Tested-by: Arm Jenkins Reviewed-by: Gunes Bayir Reviewed-by: SiCong Li Comments-Addressed: Arm Jenkins --- arm_compute/graph/backends/ValidateHelpers.h | 36 ++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) (limited to 'arm_compute/graph/backends/ValidateHelpers.h') diff --git a/arm_compute/graph/backends/ValidateHelpers.h b/arm_compute/graph/backends/ValidateHelpers.h index 93d547b036..89dccd88b7 100644 --- a/arm_compute/graph/backends/ValidateHelpers.h +++ b/arm_compute/graph/backends/ValidateHelpers.h @@ -183,6 +183,42 @@ Status validate_convolution_layer(ConvolutionLayerNode &node) return status; } +/** Validates a Convolution layer node + * + * @tparam GEMMConvolutionLayer GEMM Convolution layer function type + * + * @param[in] node Node to validate + * + * @return Status + */ +template +Status validate_fused_convolution_with_post_op(FusedConvolutionWithPostOpNode &node) +{ + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating fused ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 4); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); + + // Extract IO and info + arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0)); + arm_compute::ITensorInfo *weights = get_backing_tensor_info(node.input(1)); + arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2)); + arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0)); + + if(is_data_type_quantized_asymmetric(input->data_type())) + { + biases->set_data_type(DataType::S32); + } + + const PadStrideInfo conv_info = node.convolution_info(); + //const ConvolutionMethod conv_algorithm = node.convolution_method(); + //const bool fast_math = node.fast_math_hint() == FastMathHint::Enabled; + const unsigned int num_groups = node.num_groups(); + + // Validate function + return GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info, + WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), num_groups); +} + /** Validates a Depthwise Convolution layer node * * @tparam DepthwiseConvolutionLayer Default Depthwise Convolution layer type -- cgit v1.2.1