diff options
author | Finn Williams <Finn.Williams@arm.com> | 2021-08-09 13:00:08 +0100 |
---|---|---|
committer | Matthew Sloyan <matthew.sloyan@arm.com> | 2021-08-10 10:30:28 +0000 |
commit | d218d9804723e78da9bbd36e6211b3310426852b (patch) | |
tree | 38d4fc63d6a3ed376a094d8b4867e8c15f7d73ad /src/armnn/Network.cpp | |
parent | b20d1d4888c270d4d57a0bdcc011ded89a2f5b38 (diff) | |
download | armnn-d218d9804723e78da9bbd36e6211b3310426852b.tar.gz |
IVGCVSW-6289 Separate tensor shape inference and validation calls
* Pass m_shapeInferenceMethod to OptimizerOptions in ExecuteNetwork
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I90280fb7629092d3b66e8a3968ca9e35a0df854a
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r-- | src/armnn/Network.cpp | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 365f1bdfa1..42d7ae33ac 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -1594,13 +1594,22 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, // Get the optimized graph Graph& optGraph = optNetObjPtr->pOptimizedNetworkImpl->GetGraph(); - // Infer the tensor infos for all output slots. Throws an exception on failure - optGraph.InferTensorInfos(); + if(options.m_shapeInferenceMethod == ShapeInferenceMethod::InferAndValidate) + { + // Infer the tensor infos for all output slots. Throws an exception on failure + optGraph.InferTensorInfos(); + } // Perform AddBroadcastReshapeLayer optimisation using namespace optimizations; Optimizer::Pass(optGraph, MakeOptimizations(AddBroadcastReshapeLayer())); + if(options.m_shapeInferenceMethod == ShapeInferenceMethod::ValidateOnly) + { + // Validate the tensor infos for all output slots. Throws an exception on failure + optGraph.InferTensorInfos(); + } + // Perform optimisation passes Optimizer::Pass(optGraph, MakeOptimizations(SquashEqualPermuteSiblings(), SquashEqualTransposeSiblings(), |