diff options
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r-- | src/armnn/Network.cpp | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index e81b87b382..08d3280cfe 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017,2022,2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017, 2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -1056,7 +1056,7 @@ OptimizationResult ApplyBackendOptimizations(OptimizedNetworkImpl* optNetObjPtr, auto backendObjPtr = backends.find(selectedBackend)->second.get(); ARMNN_ASSERT(backendObjPtr); - if(selectedBackend == armnn::Compute::GpuAcc || selectedBackend == armnn::Compute::CpuAcc) + if (selectedBackend == armnn::Compute::GpuAcc || selectedBackend == armnn::Compute::CpuAcc) { Optimizer::Pass(optGraph, MakeOptimizations(optimizations::PermuteDepthwiseConv2dWeights())); Optimizer::Pass(optGraph, MakeOptimizations(optimizations::FusePermuteIntoConstLayer())); @@ -1636,10 +1636,14 @@ IOptimizedNetworkPtr Optimize(const Graph& inGraph, optGraph.InferTensorInfos(); } - // Need to FusePermuteIntoConstantLayer before FoldPadIntoDepthwiseConvolution2d or - // FuseBatchNormIntoDepthwiseConvolution2D optimizations are called. - Optimizer::Pass(optGraph, MakeOptimizations(FusePermuteIntoConstLayer())); + // Group Constant Layer optimizations together where possible. + // This is important as: + // FusePermuteIntoConstantLayer must happen before FoldPadIntoDepthwiseConvolution2d and + // FuseBatchNormIntoDepthwiseConvolution2D. + // ConvertConstDequantisationLayersToConstLayers must happen before FoldPadIntoConvolution2d + Optimizer::Pass(optGraph, MakeOptimizations(FusePermuteIntoConstLayer(), + ConvertConstDequantisationLayersToConstLayers())); // Perform optimisation passes Optimizer::Pass(optGraph, MakeOptimizations(SquashEqualPermuteSiblings(), SquashEqualTransposeSiblings(), @@ -1659,8 +1663,7 @@ IOptimizedNetworkPtr Optimize(const Graph& inGraph, FuseBatchNormIntoConvolution2DFloat32(), FuseBatchNormIntoConvolution2DFloat16(), FuseBatchNormIntoDepthwiseConvolution2DFloat32(), - FuseBatchNormIntoDepthwiseConvolution2DFloat16(), - ConvertConstDequantisationLayersToConstLayers())); + FuseBatchNormIntoDepthwiseConvolution2DFloat16())); // If Fp32 to Fp16 optimization is set convert Fp32 network to Fp16 if (options.m_ReduceFp32ToFp16) |