aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Network.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r--src/armnn/Network.cpp22
1 files changed, 15 insertions, 7 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index e0607bda33..132924a19a 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -42,14 +42,14 @@
namespace armnn
{
-armnn::INetwork* INetwork::CreateRaw()
+armnn::INetwork* INetwork::CreateRaw(NetworkOptions networkOptions)
{
- return new Network();
+ return new Network(networkOptions);
}
-armnn::INetworkPtr INetwork::Create()
+armnn::INetworkPtr INetwork::Create(NetworkOptions networkOptions)
{
- return INetworkPtr(CreateRaw(), &INetwork::Destroy);
+ return INetworkPtr(CreateRaw(networkOptions), &INetwork::Destroy);
}
void INetwork::Destroy(INetwork* network)
@@ -1147,11 +1147,19 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
return optNet;
}
-
-Network::Network()
-: m_Graph(std::make_unique<Graph>())
+bool Network::GetShapeInferenceMethod()
{
+ if (m_NetworkOptions.size() > 0 && m_NetworkOptions[0].GetBackendId().Get() == "ShapeInferenceMethod")
+ {
+ return m_NetworkOptions[0].GetOption(0).GetValue().AsBool();
+ }
+
+ return false;
}
+Network::Network(NetworkOptions networkOptions)
+: m_NetworkOptions(networkOptions),
+ m_Graph(std::make_unique<Graph>(GetShapeInferenceMethod()))
+{}
Network::~Network()
{