aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/armnn/Network.cpp13
-rw-r--r--tests/InferenceModel.hpp3
2 files changed, 14 insertions, 2 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 365f1bdfa1..42d7ae33ac 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1594,13 +1594,22 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
// Get the optimized graph
Graph& optGraph = optNetObjPtr->pOptimizedNetworkImpl->GetGraph();
- // Infer the tensor infos for all output slots. Throws an exception on failure
- optGraph.InferTensorInfos();
+ if(options.m_shapeInferenceMethod == ShapeInferenceMethod::InferAndValidate)
+ {
+ // Infer the tensor infos for all output slots. Throws an exception on failure
+ optGraph.InferTensorInfos();
+ }
// Perform AddBroadcastReshapeLayer optimisation
using namespace optimizations;
Optimizer::Pass(optGraph, MakeOptimizations(AddBroadcastReshapeLayer()));
+ if(options.m_shapeInferenceMethod == ShapeInferenceMethod::ValidateOnly)
+ {
+ // Validate the tensor infos for all output slots. Throws an exception on failure
+ optGraph.InferTensorInfos();
+ }
+
// Perform optimisation passes
Optimizer::Pass(optGraph, MakeOptimizations(SquashEqualPermuteSiblings(),
SquashEqualTransposeSiblings(),
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index 31075939ce..4d2b167522 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -440,6 +440,9 @@ public:
options.m_ReduceFp32ToBf16 = params.m_EnableBf16TurboMode;
options.m_Debug = params.m_PrintIntermediateLayers;
+ options.m_shapeInferenceMethod = params.m_InferOutputShape ?
+ armnn::ShapeInferenceMethod::InferAndValidate : armnn::ShapeInferenceMethod::ValidateOnly;
+
armnn::BackendOptions gpuAcc("GpuAcc",
{
{ "FastMathEnabled", params.m_EnableFastMath },