From d218d9804723e78da9bbd36e6211b3310426852b Mon Sep 17 00:00:00 2001 From: Finn Williams Date: Mon, 9 Aug 2021 13:00:08 +0100 Subject: IVGCVSW-6289 Separate tensor shape inference and validation calls * Pass m_shapeInferenceMethod to OptimizerOptions in ExecuteNetwork Signed-off-by: Finn Williams Change-Id: I90280fb7629092d3b66e8a3968ca9e35a0df854a --- src/armnn/Network.cpp | 13 +++++++++++-- tests/InferenceModel.hpp | 3 +++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 365f1bdfa1..42d7ae33ac 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -1594,13 +1594,22 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, // Get the optimized graph Graph& optGraph = optNetObjPtr->pOptimizedNetworkImpl->GetGraph(); - // Infer the tensor infos for all output slots. Throws an exception on failure - optGraph.InferTensorInfos(); + if(options.m_shapeInferenceMethod == ShapeInferenceMethod::InferAndValidate) + { + // Infer the tensor infos for all output slots. Throws an exception on failure + optGraph.InferTensorInfos(); + } // Perform AddBroadcastReshapeLayer optimisation using namespace optimizations; Optimizer::Pass(optGraph, MakeOptimizations(AddBroadcastReshapeLayer())); + if(options.m_shapeInferenceMethod == ShapeInferenceMethod::ValidateOnly) + { + // Validate the tensor infos for all output slots. Throws an exception on failure + optGraph.InferTensorInfos(); + } + // Perform optimisation passes Optimizer::Pass(optGraph, MakeOptimizations(SquashEqualPermuteSiblings(), SquashEqualTransposeSiblings(), diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp index 31075939ce..4d2b167522 100644 --- a/tests/InferenceModel.hpp +++ b/tests/InferenceModel.hpp @@ -440,6 +440,9 @@ public: options.m_ReduceFp32ToBf16 = params.m_EnableBf16TurboMode; options.m_Debug = params.m_PrintIntermediateLayers; + options.m_shapeInferenceMethod = params.m_InferOutputShape ? + armnn::ShapeInferenceMethod::InferAndValidate : armnn::ShapeInferenceMethod::ValidateOnly; + armnn::BackendOptions gpuAcc("GpuAcc", { { "FastMathEnabled", params.m_EnableFastMath }, -- cgit v1.2.1