aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Network.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r--src/armnn/Network.cpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 43782e0982..7b430c3ac5 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -11,6 +11,8 @@
#include <backendsCommon/CpuTensorHandle.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
+#include <backendsCommon/BackendRegistry.hpp>
+#include <backendsCommon/IBackendInternal.hpp>
#include <armnn/Exceptions.hpp>
#include <armnn/Utils.hpp>
@@ -169,6 +171,9 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
return IOptimizedNetworkPtr(nullptr, &IOptimizedNetwork::Destroy);
};
+ // The backends that we choose to run layers on
+ std::unordered_set<BackendId> chosenBackends;
+
// Assign a compute device for all nodes
bool bErrorFound = false;
for (auto&& layer : optNetObjPtr->GetGraph())
@@ -275,6 +280,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
else
{
found = true;
+ chosenBackends.insert(backend);
break;
}
}
@@ -291,6 +297,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
layerType == armnn::LayerType::Permute))
{
layer->SetBackendId(armnn::Compute::CpuRef);
+ chosenBackends.insert(armnn::Compute::CpuRef);
}
else
{
@@ -312,6 +319,20 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
Optimizer::Pass(optNetObjPtr->GetGraph(), MakeOptimizations(ConvertConstantsFloatToHalf()));
Optimizer::Pass(optNetObjPtr->GetGraph(), MakeOptimizations(ConvertConstantsHalfToFloat()));
+ // Run backend specific optimizations
+ for (auto&& chosenBackend : chosenBackends)
+ {
+ auto factoryFun = BackendRegistryInstance().GetFactory(chosenBackend);
+ auto backendPtr = factoryFun();
+ BOOST_ASSERT(backendPtr.get() != nullptr);
+
+ auto backendSpecificOptimizations = backendPtr->GetOptimizations();
+ if (!backendSpecificOptimizations.empty())
+ {
+ Optimizer::Pass(optNetObjPtr->GetGraph(), backendSpecificOptimizations);
+ }
+ }
+
return optNet;
}