diff options
author | Sheri Zhang <sheri.zhang@arm.com> | 2021-11-02 10:45:07 +0000 |
---|---|---|
committer | Sheri Zhang <sheri.zhang@arm.com> | 2021-11-03 17:08:05 +0000 |
commit | fb2280381e7a98ad698ea0c1b2cd635a48ad4acc (patch) | |
tree | e3fab3cff60b806e725ba9c771617e41c654604e /arm_compute/graph/backends/ValidateHelpers.h | |
parent | bc788389dcc7bd682f53a85803f6a202d42ac828 (diff) | |
download | ComputeLibrary-fb2280381e7a98ad698ea0c1b2cd635a48ad4acc.tar.gz |
Add graph level convolution fusion with post operator
Resolves: COMPMID-4701
Signed-off-by: Sheri Zhang <sheri.zhang@arm.com>
Change-Id: I8a0d3c2ed4bf84489d94b8ae6641d6041aadaee5
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6557
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/graph/backends/ValidateHelpers.h')
-rw-r--r-- | arm_compute/graph/backends/ValidateHelpers.h | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/arm_compute/graph/backends/ValidateHelpers.h b/arm_compute/graph/backends/ValidateHelpers.h index 93d547b036..89dccd88b7 100644 --- a/arm_compute/graph/backends/ValidateHelpers.h +++ b/arm_compute/graph/backends/ValidateHelpers.h @@ -183,6 +183,42 @@ Status validate_convolution_layer(ConvolutionLayerNode &node) return status; } +/** Validates a Convolution layer node + * + * @tparam GEMMConvolutionLayer GEMM Convolution layer function type + * + * @param[in] node Node to validate + * + * @return Status + */ +template <typename GEMMConvolutionLayer> +Status validate_fused_convolution_with_post_op(FusedConvolutionWithPostOpNode &node) +{ + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating fused ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 4); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); + + // Extract IO and info + arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0)); + arm_compute::ITensorInfo *weights = get_backing_tensor_info(node.input(1)); + arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2)); + arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0)); + + if(is_data_type_quantized_asymmetric(input->data_type())) + { + biases->set_data_type(DataType::S32); + } + + const PadStrideInfo conv_info = node.convolution_info(); + //const ConvolutionMethod conv_algorithm = node.convolution_method(); + //const bool fast_math = node.fast_math_hint() == FastMathHint::Enabled; + const unsigned int num_groups = node.num_groups(); + + // Validate function + return GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info, + WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), num_groups); +} + /** Validates a Depthwise Convolution layer node * * @tparam DepthwiseConvolutionLayer Default Depthwise Convolution layer type |