aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/Types.h
diff options
context:
space:
mode:
authorSheri Zhang <sheri.zhang@arm.com>2021-11-02 10:45:07 +0000
committerSheri Zhang <sheri.zhang@arm.com>2021-11-03 17:08:05 +0000
commitfb2280381e7a98ad698ea0c1b2cd635a48ad4acc (patch)
treee3fab3cff60b806e725ba9c771617e41c654604e /arm_compute/graph/Types.h
parentbc788389dcc7bd682f53a85803f6a202d42ac828 (diff)
downloadComputeLibrary-fb2280381e7a98ad698ea0c1b2cd635a48ad4acc.tar.gz
Add graph level convolution fusion with post operator
Resolves: COMPMID-4701 Signed-off-by: Sheri Zhang <sheri.zhang@arm.com> Change-Id: I8a0d3c2ed4bf84489d94b8ae6641d6041aadaee5 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6557 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-by: SiCong Li <sicong.li@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/graph/Types.h')
-rw-r--r--arm_compute/graph/Types.h51
1 files changed, 51 insertions, 0 deletions
diff --git a/arm_compute/graph/Types.h b/arm_compute/graph/Types.h
index 63a9433fe6..e802e9dc77 100644
--- a/arm_compute/graph/Types.h
+++ b/arm_compute/graph/Types.h
@@ -62,6 +62,7 @@ using arm_compute::PoolingType;
using arm_compute::PriorBoxLayerInfo;
using arm_compute::DimensionRoundingType;
using arm_compute::InterpolationPolicy;
+using arm_compute::experimental::PostOpType;
using GraphID = unsigned int;
using TensorID = unsigned int;
@@ -145,6 +146,55 @@ enum class FastMathHint
Disabled, /**< Fast math disabled for Convolution layer */
};
+/** Convolution post operator info */
+class ConvPostOpInfo
+{
+public:
+ /** Returns post op type
+ *
+ * @return Post op type
+ */
+ virtual PostOpType type() const = 0;
+ virtual ~ConvPostOpInfo()
+ {
+ }
+};
+
+class ConvPostOpInfoActivation : public ConvPostOpInfo
+{
+public:
+ ConvPostOpInfoActivation(const ActivationLayerInfo &act)
+ : _act(act)
+ {
+ }
+ ~ConvPostOpInfoActivation() override
+ {
+ }
+ PostOpType type() const override
+ {
+ return PostOpType::Activation;
+ }
+ ActivationLayerInfo _act;
+};
+
+class ConvPostOpInfoEltwiseAdd : public ConvPostOpInfo
+{
+public:
+ ConvPostOpInfoEltwiseAdd(int arg_pos, const ConvertPolicy &policy)
+ : _prev_op_dst_pos(arg_pos), _policy(policy)
+ {
+ }
+ PostOpType type() const override
+ {
+ return PostOpType::Eltwise_Add;
+ }
+ ~ConvPostOpInfoEltwiseAdd() override
+ {
+ }
+ int _prev_op_dst_pos;
+ ConvertPolicy _policy;
+};
+
/** Supported nodes */
enum class NodeType
{
@@ -165,6 +215,7 @@ enum class NodeType
FlattenLayer,
FullyConnectedLayer,
FusedConvolutionBatchNormalizationLayer,
+ FusedConvolutionWithPostOp,
FusedDepthwiseConvolutionBatchNormalizationLayer,
GenerateProposalsLayer,
L2NormalizeLayer,