aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/backends/ValidateHelpers.h
diff options
context:
space:
mode:
authorSheri Zhang <sheri.zhang@arm.com>2021-11-02 10:45:07 +0000
committerSheri Zhang <sheri.zhang@arm.com>2021-11-03 17:08:05 +0000
commitfb2280381e7a98ad698ea0c1b2cd635a48ad4acc (patch)
treee3fab3cff60b806e725ba9c771617e41c654604e /arm_compute/graph/backends/ValidateHelpers.h
parentbc788389dcc7bd682f53a85803f6a202d42ac828 (diff)
downloadComputeLibrary-fb2280381e7a98ad698ea0c1b2cd635a48ad4acc.tar.gz
Add graph level convolution fusion with post operator
Resolves: COMPMID-4701 Signed-off-by: Sheri Zhang <sheri.zhang@arm.com> Change-Id: I8a0d3c2ed4bf84489d94b8ae6641d6041aadaee5 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6557 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-by: SiCong Li <sicong.li@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/graph/backends/ValidateHelpers.h')
-rw-r--r--arm_compute/graph/backends/ValidateHelpers.h36
1 files changed, 36 insertions, 0 deletions
diff --git a/arm_compute/graph/backends/ValidateHelpers.h b/arm_compute/graph/backends/ValidateHelpers.h
index 93d547b036..89dccd88b7 100644
--- a/arm_compute/graph/backends/ValidateHelpers.h
+++ b/arm_compute/graph/backends/ValidateHelpers.h
@@ -183,6 +183,42 @@ Status validate_convolution_layer(ConvolutionLayerNode &node)
return status;
}
+/** Validates a Convolution layer node
+ *
+ * @tparam GEMMConvolutionLayer GEMM Convolution layer function type
+ *
+ * @param[in] node Node to validate
+ *
+ * @return Status
+ */
+template <typename GEMMConvolutionLayer>
+Status validate_fused_convolution_with_post_op(FusedConvolutionWithPostOpNode &node)
+{
+ ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating fused ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
+ ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 4);
+ ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
+
+ // Extract IO and info
+ arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
+ arm_compute::ITensorInfo *weights = get_backing_tensor_info(node.input(1));
+ arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2));
+ arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
+
+ if(is_data_type_quantized_asymmetric(input->data_type()))
+ {
+ biases->set_data_type(DataType::S32);
+ }
+
+ const PadStrideInfo conv_info = node.convolution_info();
+ //const ConvolutionMethod conv_algorithm = node.convolution_method();
+ //const bool fast_math = node.fast_math_hint() == FastMathHint::Enabled;
+ const unsigned int num_groups = node.num_groups();
+
+ // Validate function
+ return GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info,
+ WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), num_groups);
+}
+
/** Validates a Depthwise Convolution layer node
*
* @tparam DepthwiseConvolutionLayer Default Depthwise Convolution layer type