From cadb368b0827601647c3d1fd66689f96473af5cb Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 29 Mar 2019 10:54:36 +0000 Subject: COMPMID-1995: Fixed graph fusion mutator for float types. -Fixes precondition checks for fusing activation with other nodes. -Fixes is_relu6 check to capture bounded relu as well. Change-Id: Iba193af51491b537c884a35ca85172151534f3ec Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/918 Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- arm_compute/core/utils/misc/InfoHelpers.h | 8 +++++--- src/graph/mutators/NodeFusionMutator.cpp | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/arm_compute/core/utils/misc/InfoHelpers.h b/arm_compute/core/utils/misc/InfoHelpers.h index 704e178292..8197862700 100644 --- a/arm_compute/core/utils/misc/InfoHelpers.h +++ b/arm_compute/core/utils/misc/InfoHelpers.h @@ -52,9 +52,11 @@ inline bool is_relu(ActivationLayerInfo activation_info) */ inline bool is_relu6(ActivationLayerInfo activation_info) { - return activation_info.enabled() - && activation_info.activation() == ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU - && activation_info.a() == 6.f && activation_info.b() == 0.f; + const bool is_lu_bounded_relu = activation_info.activation() == ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU + && activation_info.a() == 6.f && activation_info.b() == 0.f; + const bool is_bounded_relu = activation_info.activation() == ActivationLayerInfo::ActivationFunction::BOUNDED_RELU + && activation_info.a() == 6.f; + return activation_info.enabled() && (is_lu_bounded_relu || is_bounded_relu); } } // namespace info_helpers } // namespace utils diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp index 724307e7b7..b28f2dbd2e 100644 --- a/src/graph/mutators/NodeFusionMutator.cpp +++ b/src/graph/mutators/NodeFusionMutator.cpp @@ -221,7 +221,7 @@ void NodeFusionMutator::mutate(Graph &g) const bool same_qinfo = n.output(0)->desc().quant_info == output_edge->producer()->output(0)->desc().quant_info; const bool output_qasymm8 = n.output(0)->desc().data_type == DataType::QASYMM8; - return output_qasymm8 && same_qinfo; + return (output_qasymm8 && same_qinfo) || !output_qasymm8; }; // Fusion mutations -- cgit v1.2.1