author     Gian Marco Iodice <gianmarco.iodice@arm.com>    2020-09-21 14:22:25 +0100
committer  Georgios Pinitas <georgios.pinitas@arm.com>     2020-09-22 09:29:05 +0000
commit     047c6fcd2ead657ea251a251893767aa90d6bde3 (patch)
tree       44daa3e011e09dc2de5af9d17fb4a18938c2886c /src/graph/mutators/NodeFusionMutator.cpp
parent     34654b2d8dcaf268a9d1bf9e0cdb5ba548ced2b7 (diff)
COMPMID-3791: Add support for all activation types in NodeFusionMutator
Change-Id: I9b548966201c00df8290fea7acf55c2173b0e0aa
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4011
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'src/graph/mutators/NodeFusionMutator.cpp')
-rw-r--r--  src/graph/mutators/NodeFusionMutator.cpp | 19
1 file changed, 10 insertions(+), 9 deletions(-)
diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp
index afc4452202..2a80825b36 100644
--- a/src/graph/mutators/NodeFusionMutator.cpp
+++ b/src/graph/mutators/NodeFusionMutator.cpp
@@ -300,10 +300,11 @@ IGraphMutator::MutationType NodeFusionMutator::type() const
void NodeFusionMutator::mutate(Graph &g)
{
// Supported activations when fusing
- const std::set<Activation> supported_fused_activations_conv = { Activation::RELU, Activation::BOUNDED_RELU, Activation::LU_BOUNDED_RELU };
- const std::set<Activation> supported_fused_activations_eltwise = { Activation::RELU, Activation::BOUNDED_RELU, Activation::LU_BOUNDED_RELU,
-                                                                    Activation::TANH, Activation::LOGISTIC
-                                                                  };
+ const std::set<Activation> supported_fused_activations = { Activation::ABS, Activation::BOUNDED_RELU, Activation::ELU,
+                                                            Activation::HARD_SWISH, Activation::IDENTITY, Activation::LEAKY_RELU,
+                                                            Activation::LINEAR, Activation::LOGISTIC, Activation::LU_BOUNDED_RELU,
+                                                            Activation::RELU, Activation::SOFT_RELU, Activation::SQRT,
+                                                            Activation::SQUARE, Activation::TANH };
// Preconditions
auto empty_prec = [](INode &)
@@ -328,11 +329,11 @@ void NodeFusionMutator::mutate(Graph &g)
};
// Fusion mutations
- detail::fuse_layer<BatchNormalizationLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<BatchNormalizationLayerNode>, supported_fused_activations_conv);
- detail::fuse_layer<ConvolutionLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<ConvolutionLayerNode>, supported_fused_activations_conv);
- detail::fuse_layer<DepthwiseConvolutionLayerNode, ActivationLayerNode>(g, qs8_prec, detail::fuse_node_with_activation<DepthwiseConvolutionLayerNode>, supported_fused_activations_conv);
- detail::fuse_layer<FullyConnectedLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<FullyConnectedLayerNode>, supported_fused_activations_conv);
- detail::fuse_layer<EltwiseLayerNode, ActivationLayerNode>(g, cl_target_prec, detail::fuse_node_with_activation<EltwiseLayerNode>, supported_fused_activations_eltwise);
+ detail::fuse_layer<BatchNormalizationLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<BatchNormalizationLayerNode>, supported_fused_activations);
+ detail::fuse_layer<ConvolutionLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<ConvolutionLayerNode>, supported_fused_activations);
+ detail::fuse_layer<DepthwiseConvolutionLayerNode, ActivationLayerNode>(g, qs8_prec, detail::fuse_node_with_activation<DepthwiseConvolutionLayerNode>, supported_fused_activations);
+ detail::fuse_layer<FullyConnectedLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<FullyConnectedLayerNode>, supported_fused_activations);
+ detail::fuse_layer<EltwiseLayerNode, ActivationLayerNode>(g, cl_target_prec, detail::fuse_node_with_activation<EltwiseLayerNode>, supported_fused_activations);
detail::fuse_layer<ConvolutionLayerNode, BatchNormalizationLayerNode>(g, empty_prec, detail::fuse_convolution_with_batch_normalization);
detail::fuse_layer<DepthwiseConvolutionLayerNode, BatchNormalizationLayerNode>(g, empty_prec, detail::fuse_depthwise_convolution_with_batch_normalization);
}
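For readers skimming the change: the commit collapses the two per-operator allowlists (conv and eltwise) into a single supported_fused_activations set that every fuse_layer call now shares, so all activation types are eligible for fusion regardless of the producing node. The following standalone C++ sketch is illustrative only, not the library's actual detail::fuse_layer implementation; the Activation enum and can_fuse_activation helper are stand-ins showing how such an allowlist typically gates whether an activation node is folded into its producer.

#include <iostream>
#include <set>

// Illustrative stand-in for the graph Activation enum used by the mutator;
// the real enum lives in the ComputeLibrary graph API.
enum class Activation
{
    ABS, BOUNDED_RELU, ELU, HARD_SWISH, IDENTITY, LEAKY_RELU, LINEAR,
    LOGISTIC, LU_BOUNDED_RELU, RELU, SOFT_RELU, SQRT, SQUARE, TANH
};

// Hypothetical helper: fusion of a node with a following activation layer is
// attempted only when the activation appears in the allowlist.
bool can_fuse_activation(Activation act, const std::set<Activation> &supported)
{
    return supported.find(act) != supported.end();
}

int main()
{
    // Mirrors the single allowlist introduced by this change.
    const std::set<Activation> supported_fused_activations = {
        Activation::ABS, Activation::BOUNDED_RELU, Activation::ELU,
        Activation::HARD_SWISH, Activation::IDENTITY, Activation::LEAKY_RELU,
        Activation::LINEAR, Activation::LOGISTIC, Activation::LU_BOUNDED_RELU,
        Activation::RELU, Activation::SOFT_RELU, Activation::SQRT,
        Activation::SQUARE, Activation::TANH
    };

    // HARD_SWISH was not in the previous convolution allowlist, so a
    // convolution followed by a hard-swish activation would not have been
    // fused before this commit.
    std::cout << std::boolalpha
              << can_fuse_activation(Activation::HARD_SWISH, supported_fused_activations)
              << std::endl; // prints: true
    return 0;
}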