From fb2280381e7a98ad698ea0c1b2cd635a48ad4acc Mon Sep 17 00:00:00 2001 From: Sheri Zhang Date: Tue, 2 Nov 2021 10:45:07 +0000 Subject: Add graph level convolution fusion with post operator Resolves: COMPMID-4701 Signed-off-by: Sheri Zhang Change-Id: I8a0d3c2ed4bf84489d94b8ae6641d6041aadaee5 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6557 Tested-by: Arm Jenkins Reviewed-by: Gunes Bayir Reviewed-by: SiCong Li Comments-Addressed: Arm Jenkins --- arm_compute/graph/Types.h | 51 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) (limited to 'arm_compute/graph/Types.h') diff --git a/arm_compute/graph/Types.h b/arm_compute/graph/Types.h index 63a9433fe6..e802e9dc77 100644 --- a/arm_compute/graph/Types.h +++ b/arm_compute/graph/Types.h @@ -62,6 +62,7 @@ using arm_compute::PoolingType; using arm_compute::PriorBoxLayerInfo; using arm_compute::DimensionRoundingType; using arm_compute::InterpolationPolicy; +using arm_compute::experimental::PostOpType; using GraphID = unsigned int; using TensorID = unsigned int; @@ -145,6 +146,55 @@ enum class FastMathHint Disabled, /**< Fast math disabled for Convolution layer */ }; +/** Convolution post operator info */ +class ConvPostOpInfo +{ +public: + /** Returns post op type + * + * @return Post op type + */ + virtual PostOpType type() const = 0; + virtual ~ConvPostOpInfo() + { + } +}; + +class ConvPostOpInfoActivation : public ConvPostOpInfo +{ +public: + ConvPostOpInfoActivation(const ActivationLayerInfo &act) + : _act(act) + { + } + ~ConvPostOpInfoActivation() override + { + } + PostOpType type() const override + { + return PostOpType::Activation; + } + ActivationLayerInfo _act; +}; + +class ConvPostOpInfoEltwiseAdd : public ConvPostOpInfo +{ +public: + ConvPostOpInfoEltwiseAdd(int arg_pos, const ConvertPolicy &policy) + : _prev_op_dst_pos(arg_pos), _policy(policy) + { + } + PostOpType type() const override + { + return PostOpType::Eltwise_Add; + } + ~ConvPostOpInfoEltwiseAdd() override + { + } + int _prev_op_dst_pos; + ConvertPolicy _policy; +}; + /** Supported nodes */ enum class NodeType { @@ -165,6 +215,7 @@ enum class NodeType FlattenLayer, FullyConnectedLayer, FusedConvolutionBatchNormalizationLayer, + FusedConvolutionWithPostOp, FusedDepthwiseConvolutionBatchNormalizationLayer, GenerateProposalsLayer, L2NormalizeLayer, -- cgit v1.2.1