diff options
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r-- | src/armnn/Network.cpp | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 43782e0982..7b430c3ac5 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -11,6 +11,8 @@ #include <backendsCommon/CpuTensorHandle.hpp> #include <backendsCommon/WorkloadFactory.hpp> +#include <backendsCommon/BackendRegistry.hpp> +#include <backendsCommon/IBackendInternal.hpp> #include <armnn/Exceptions.hpp> #include <armnn/Utils.hpp> @@ -169,6 +171,9 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, return IOptimizedNetworkPtr(nullptr, &IOptimizedNetwork::Destroy); }; + // The backends that we choose to run layers on + std::unordered_set<BackendId> chosenBackends; + // Assign a compute device for all nodes bool bErrorFound = false; for (auto&& layer : optNetObjPtr->GetGraph()) @@ -275,6 +280,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, else { found = true; + chosenBackends.insert(backend); break; } } @@ -291,6 +297,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, layerType == armnn::LayerType::Permute)) { layer->SetBackendId(armnn::Compute::CpuRef); + chosenBackends.insert(armnn::Compute::CpuRef); } else { @@ -312,6 +319,20 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, Optimizer::Pass(optNetObjPtr->GetGraph(), MakeOptimizations(ConvertConstantsFloatToHalf())); Optimizer::Pass(optNetObjPtr->GetGraph(), MakeOptimizations(ConvertConstantsHalfToFloat())); + // Run backend specific optimizations + for (auto&& chosenBackend : chosenBackends) + { + auto factoryFun = BackendRegistryInstance().GetFactory(chosenBackend); + auto backendPtr = factoryFun(); + BOOST_ASSERT(backendPtr.get() != nullptr); + + auto backendSpecificOptimizations = backendPtr->GetOptimizations(); + if (!backendSpecificOptimizations.empty()) + { + Optimizer::Pass(optNetObjPtr->GetGraph(), backendSpecificOptimizations); + } + } + return optNet; } |