diff options
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r-- | src/armnn/Network.cpp | 29 |
1 files changed, 24 insertions, 5 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 77ad5c4dc2..6a646d3cc8 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -1854,16 +1854,35 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, bool NetworkImpl::GetShapeInferenceMethod() { - if (m_NetworkOptions.size() > 0 && m_NetworkOptions[0].GetBackendId().Get() == "ShapeInferenceMethod") + bool shapeInferenceMethod = false; + + ParseOptions(m_NetworkOptions, "ShapeInferenceMethod", [&](std::string name, const BackendOptions::Var& value) { - return m_NetworkOptions[0].GetOption(0).GetValue().AsBool(); - } + if (name == "InferAndValidate") + { + shapeInferenceMethod |= value.AsBool(); + } + }); + return shapeInferenceMethod; +} - return false; +bool NetworkImpl::GetAllowExpandedDims() +{ + bool allowExpandedDims = false; + + ParseOptions(m_NetworkOptions, "AllowExpandedDims", [&](std::string name, const BackendOptions::Var& value) + { + if (name == "AllowExpandedDims") + { + allowExpandedDims |= value.AsBool(); + } + }); + return allowExpandedDims; } + NetworkImpl::NetworkImpl(NetworkOptions networkOptions) : m_NetworkOptions(networkOptions), - m_Graph(std::make_unique<Graph>(GetShapeInferenceMethod())) + m_Graph(std::make_unique<Graph>(GetShapeInferenceMethod(), GetAllowExpandedDims())) {} NetworkImpl::~NetworkImpl() |