aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Network.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r--src/armnn/Network.cpp17
1 files changed, 10 insertions, 7 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index e81b87b382..08d3280cfe 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017,2022,2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017, 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -1056,7 +1056,7 @@ OptimizationResult ApplyBackendOptimizations(OptimizedNetworkImpl* optNetObjPtr,
auto backendObjPtr = backends.find(selectedBackend)->second.get();
ARMNN_ASSERT(backendObjPtr);
- if(selectedBackend == armnn::Compute::GpuAcc || selectedBackend == armnn::Compute::CpuAcc)
+ if (selectedBackend == armnn::Compute::GpuAcc || selectedBackend == armnn::Compute::CpuAcc)
{
Optimizer::Pass(optGraph, MakeOptimizations(optimizations::PermuteDepthwiseConv2dWeights()));
Optimizer::Pass(optGraph, MakeOptimizations(optimizations::FusePermuteIntoConstLayer()));
@@ -1636,10 +1636,14 @@ IOptimizedNetworkPtr Optimize(const Graph& inGraph,
optGraph.InferTensorInfos();
}
- // Need to FusePermuteIntoConstantLayer before FoldPadIntoDepthwiseConvolution2d or
- // FuseBatchNormIntoDepthwiseConvolution2D optimizations are called.
- Optimizer::Pass(optGraph, MakeOptimizations(FusePermuteIntoConstLayer()));
+ // Group Constant Layer optimizations together where possible.
+ // This is important as:
+ // FusePermuteIntoConstLayer must happen before FoldPadIntoDepthwiseConvolution2d and
+ // FuseBatchNormIntoDepthwiseConvolution2D.
+ // ConvertConstDequantisationLayersToConstLayers must happen before FoldPadIntoConvolution2d.
+ Optimizer::Pass(optGraph, MakeOptimizations(FusePermuteIntoConstLayer(),
+ ConvertConstDequantisationLayersToConstLayers()));
// Perform optimisation passes
Optimizer::Pass(optGraph, MakeOptimizations(SquashEqualPermuteSiblings(),
SquashEqualTransposeSiblings(),
@@ -1659,8 +1663,7 @@ IOptimizedNetworkPtr Optimize(const Graph& inGraph,
FuseBatchNormIntoConvolution2DFloat32(),
FuseBatchNormIntoConvolution2DFloat16(),
FuseBatchNormIntoDepthwiseConvolution2DFloat32(),
- FuseBatchNormIntoDepthwiseConvolution2DFloat16(),
- ConvertConstDequantisationLayersToConstLayers()));
+ FuseBatchNormIntoDepthwiseConvolution2DFloat16()));
// If Fp32 to Fp16 optimization is set convert Fp32 network to Fp16
if (options.m_ReduceFp32ToFp16)