diff options
-rw-r--r-- | src/armnn/Network.cpp | 21 | ||||
-rw-r--r-- | src/backends/backendsCommon/IBackendInternal.hpp | 5 | ||||
-rw-r--r-- | src/backends/cl/ClBackend.cpp | 6 | ||||
-rw-r--r-- | src/backends/cl/ClBackend.hpp | 2 | ||||
-rw-r--r-- | src/backends/neon/NeonBackend.cpp | 12 | ||||
-rw-r--r-- | src/backends/neon/NeonBackend.hpp | 8 | ||||
-rw-r--r-- | src/backends/reference/RefBackend.cpp | 12 | ||||
-rw-r--r-- | src/backends/reference/RefBackend.hpp | 8 |
8 files changed, 61 insertions, 13 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 43782e0982..7b430c3ac5 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -11,6 +11,8 @@ #include <backendsCommon/CpuTensorHandle.hpp> #include <backendsCommon/WorkloadFactory.hpp> +#include <backendsCommon/BackendRegistry.hpp> +#include <backendsCommon/IBackendInternal.hpp> #include <armnn/Exceptions.hpp> #include <armnn/Utils.hpp> @@ -169,6 +171,9 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, return IOptimizedNetworkPtr(nullptr, &IOptimizedNetwork::Destroy); }; + // The backends that we choose to run layers on + std::unordered_set<BackendId> chosenBackends; + // Assign a compute device for all nodes bool bErrorFound = false; for (auto&& layer : optNetObjPtr->GetGraph()) @@ -275,6 +280,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, else { found = true; + chosenBackends.insert(backend); break; } } @@ -291,6 +297,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, layerType == armnn::LayerType::Permute)) { layer->SetBackendId(armnn::Compute::CpuRef); + chosenBackends.insert(armnn::Compute::CpuRef); } else { @@ -312,6 +319,20 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, Optimizer::Pass(optNetObjPtr->GetGraph(), MakeOptimizations(ConvertConstantsFloatToHalf())); Optimizer::Pass(optNetObjPtr->GetGraph(), MakeOptimizations(ConvertConstantsHalfToFloat())); + // Run backend specific optimizations + for (auto&& chosenBackend : chosenBackends) + { + auto factoryFun = BackendRegistryInstance().GetFactory(chosenBackend); + auto backendPtr = factoryFun(); + BOOST_ASSERT(backendPtr.get() != nullptr); + + auto backendSpecificOptimizations = backendPtr->GetOptimizations(); + if (!backendSpecificOptimizations.empty()) + { + Optimizer::Pass(optNetObjPtr->GetGraph(), backendSpecificOptimizations); + } + } + return optNet; } diff --git a/src/backends/backendsCommon/IBackendInternal.hpp b/src/backends/backendsCommon/IBackendInternal.hpp index fede366475..9c54b821e7 100644 --- a/src/backends/backendsCommon/IBackendInternal.hpp +++ b/src/backends/backendsCommon/IBackendInternal.hpp @@ -6,11 +6,13 @@ #include <armnn/Types.hpp> #include <armnn/IRuntime.hpp> +#include <vector> namespace armnn { class IWorkloadFactory; class IBackendContext; +class Optimization; class IBackendInternal : public IBackend { @@ -26,9 +28,12 @@ public: using IWorkloadFactoryPtr = std::unique_ptr<IWorkloadFactory>; using IBackendContextPtr = std::unique_ptr<IBackendContext>; + using OptimizationPtr = std::unique_ptr<Optimization>; + using Optimizations = std::vector<OptimizationPtr>; virtual IWorkloadFactoryPtr CreateWorkloadFactory() const = 0; virtual IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const = 0; + virtual Optimizations GetOptimizations() const = 0; }; using IBackendInternalUniquePtr = std::unique_ptr<IBackendInternal>; diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp index c07fa66457..8209a109a4 100644 --- a/src/backends/cl/ClBackend.cpp +++ b/src/backends/cl/ClBackend.cpp @@ -8,7 +8,9 @@ #include "ClWorkloadFactory.hpp" #include "ClBackendContext.hpp" +#include <backendsCommon/IBackendContext.hpp> #include <backendsCommon/BackendRegistry.hpp> +#include <Optimizer.hpp> namespace armnn { @@ -45,5 +47,9 @@ ClBackend::CreateBackendContext(const IRuntime::CreationOptions& options) const return IBackendContextPtr{new ClBackendContext{options}}; } +IBackendInternal::Optimizations ClBackend::GetOptimizations() const +{ + return Optimizations{}; +} } // namespace armnn diff --git a/src/backends/cl/ClBackend.hpp b/src/backends/cl/ClBackend.hpp index f8a6253c22..ad84e8a159 100644 --- a/src/backends/cl/ClBackend.hpp +++ b/src/backends/cl/ClBackend.hpp @@ -4,7 +4,6 @@ // #pragma once -#include <backendsCommon/IBackendContext.hpp> #include <backendsCommon/IBackendInternal.hpp> namespace armnn @@ -21,6 +20,7 @@ public: IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory() const override; IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override; + IBackendInternal::Optimizations GetOptimizations() const override; }; } // namespace armnn
\ No newline at end of file diff --git a/src/backends/neon/NeonBackend.cpp b/src/backends/neon/NeonBackend.cpp index 7058d24e72..9e079f38ce 100644 --- a/src/backends/neon/NeonBackend.cpp +++ b/src/backends/neon/NeonBackend.cpp @@ -7,7 +7,9 @@ #include "NeonBackendId.hpp" #include "NeonWorkloadFactory.hpp" +#include <backendsCommon/IBackendContext.hpp> #include <backendsCommon/BackendRegistry.hpp> +#include <Optimizer.hpp> #include <boost/cast.hpp> @@ -40,4 +42,14 @@ IBackendInternal::IWorkloadFactoryPtr NeonBackend::CreateWorkloadFactory() const return std::make_unique<NeonWorkloadFactory>(); } +IBackendInternal::IBackendContextPtr NeonBackend::CreateBackendContext(const IRuntime::CreationOptions&) const +{ + return IBackendContextPtr{}; +} + +IBackendInternal::Optimizations NeonBackend::GetOptimizations() const +{ + return Optimizations{}; +} + } // namespace armnn
\ No newline at end of file diff --git a/src/backends/neon/NeonBackend.hpp b/src/backends/neon/NeonBackend.hpp index 9ee8b238b3..e0017d92c8 100644 --- a/src/backends/neon/NeonBackend.hpp +++ b/src/backends/neon/NeonBackend.hpp @@ -4,7 +4,6 @@ // #pragma once -#include <backendsCommon/IBackendContext.hpp> #include <backendsCommon/IBackendInternal.hpp> namespace armnn @@ -20,11 +19,8 @@ public: const BackendId& GetId() const override { return GetIdStatic(); } IWorkloadFactoryPtr CreateWorkloadFactory() const override; - - IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override - { - return IBackendContextPtr{}; - } + IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override; + IBackendInternal::Optimizations GetOptimizations() const override; }; } // namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/RefBackend.cpp b/src/backends/reference/RefBackend.cpp index b6fb0ff5ec..2f5ec8032c 100644 --- a/src/backends/reference/RefBackend.cpp +++ b/src/backends/reference/RefBackend.cpp @@ -7,7 +7,9 @@ #include "RefBackendId.hpp" #include "RefWorkloadFactory.hpp" +#include <backendsCommon/IBackendContext.hpp> #include <backendsCommon/BackendRegistry.hpp> +#include <Optimizer.hpp> #include <boost/cast.hpp> @@ -40,4 +42,14 @@ IBackendInternal::IWorkloadFactoryPtr RefBackend::CreateWorkloadFactory() const return std::make_unique<RefWorkloadFactory>(); } +IBackendInternal::IBackendContextPtr RefBackend::CreateBackendContext(const IRuntime::CreationOptions&) const +{ + return IBackendContextPtr{}; +} + +IBackendInternal::Optimizations RefBackend::GetOptimizations() const +{ + return Optimizations{}; +} + } // namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/RefBackend.hpp b/src/backends/reference/RefBackend.hpp index 025a4826b2..be71f356f3 100644 --- a/src/backends/reference/RefBackend.hpp +++ b/src/backends/reference/RefBackend.hpp @@ -4,7 +4,6 @@ // #pragma once -#include <backendsCommon/IBackendContext.hpp> #include <backendsCommon/IBackendInternal.hpp> namespace armnn @@ -20,11 +19,8 @@ public: const BackendId& GetId() const override { return GetIdStatic(); } IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory() const override; - - IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override - { - return IBackendContextPtr{}; - } + IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override; + IBackendInternal::Optimizations GetOptimizations() const override; }; } // namespace armnn
\ No newline at end of file |