From 263e34988abe54d79133850182190661bfd977df Mon Sep 17 00:00:00 2001 From: David Beck Date: Fri, 9 Nov 2018 14:46:40 +0000 Subject: IVGCVSW-2125 : backends now can return optimizations Change-Id: Ieec34224b433e1d2f3bbe66632cd6016cac5498c --- src/armnn/Network.cpp | 21 +++++++++++++++++++++ src/backends/backendsCommon/IBackendInternal.hpp | 5 +++++ src/backends/cl/ClBackend.cpp | 6 ++++++ src/backends/cl/ClBackend.hpp | 2 +- src/backends/neon/NeonBackend.cpp | 12 ++++++++++++ src/backends/neon/NeonBackend.hpp | 8 ++------ src/backends/reference/RefBackend.cpp | 12 ++++++++++++ src/backends/reference/RefBackend.hpp | 8 ++------ 8 files changed, 61 insertions(+), 13 deletions(-) diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 43782e0982..7b430c3ac5 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -11,6 +11,8 @@ #include #include +#include +#include #include #include @@ -169,6 +171,9 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, return IOptimizedNetworkPtr(nullptr, &IOptimizedNetwork::Destroy); }; + // The backends that we choose to run layers on + std::unordered_set chosenBackends; + // Assign a compute device for all nodes bool bErrorFound = false; for (auto&& layer : optNetObjPtr->GetGraph()) @@ -275,6 +280,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, else { found = true; + chosenBackends.insert(backend); break; } } @@ -291,6 +297,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, layerType == armnn::LayerType::Permute)) { layer->SetBackendId(armnn::Compute::CpuRef); + chosenBackends.insert(armnn::Compute::CpuRef); } else { @@ -312,6 +319,20 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, Optimizer::Pass(optNetObjPtr->GetGraph(), MakeOptimizations(ConvertConstantsFloatToHalf())); Optimizer::Pass(optNetObjPtr->GetGraph(), MakeOptimizations(ConvertConstantsHalfToFloat())); + // Run backend specific optimizations + for (auto&& chosenBackend : chosenBackends) + { + auto factoryFun = BackendRegistryInstance().GetFactory(chosenBackend); + auto backendPtr = factoryFun(); + BOOST_ASSERT(backendPtr.get() != nullptr); + + auto backendSpecificOptimizations = backendPtr->GetOptimizations(); + if (!backendSpecificOptimizations.empty()) + { + Optimizer::Pass(optNetObjPtr->GetGraph(), backendSpecificOptimizations); + } + } + return optNet; } diff --git a/src/backends/backendsCommon/IBackendInternal.hpp b/src/backends/backendsCommon/IBackendInternal.hpp index fede366475..9c54b821e7 100644 --- a/src/backends/backendsCommon/IBackendInternal.hpp +++ b/src/backends/backendsCommon/IBackendInternal.hpp @@ -6,11 +6,13 @@ #include #include +#include namespace armnn { class IWorkloadFactory; class IBackendContext; +class Optimization; class IBackendInternal : public IBackend { @@ -26,9 +28,12 @@ public: using IWorkloadFactoryPtr = std::unique_ptr; using IBackendContextPtr = std::unique_ptr; + using OptimizationPtr = std::unique_ptr; + using Optimizations = std::vector; virtual IWorkloadFactoryPtr CreateWorkloadFactory() const = 0; virtual IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const = 0; + virtual Optimizations GetOptimizations() const = 0; }; using IBackendInternalUniquePtr = std::unique_ptr; diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp index c07fa66457..8209a109a4 100644 --- a/src/backends/cl/ClBackend.cpp +++ b/src/backends/cl/ClBackend.cpp @@ -8,7 +8,9 @@ #include "ClWorkloadFactory.hpp" #include "ClBackendContext.hpp" +#include #include +#include namespace armnn { @@ -45,5 +47,9 @@ ClBackend::CreateBackendContext(const IRuntime::CreationOptions& options) const return IBackendContextPtr{new ClBackendContext{options}}; } +IBackendInternal::Optimizations ClBackend::GetOptimizations() const +{ + return Optimizations{}; +} } // namespace armnn diff --git a/src/backends/cl/ClBackend.hpp b/src/backends/cl/ClBackend.hpp index f8a6253c22..ad84e8a159 100644 --- a/src/backends/cl/ClBackend.hpp +++ b/src/backends/cl/ClBackend.hpp @@ -4,7 +4,6 @@ // #pragma once -#include #include namespace armnn @@ -21,6 +20,7 @@ public: IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory() const override; IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override; + IBackendInternal::Optimizations GetOptimizations() const override; }; } // namespace armnn \ No newline at end of file diff --git a/src/backends/neon/NeonBackend.cpp b/src/backends/neon/NeonBackend.cpp index 7058d24e72..9e079f38ce 100644 --- a/src/backends/neon/NeonBackend.cpp +++ b/src/backends/neon/NeonBackend.cpp @@ -7,7 +7,9 @@ #include "NeonBackendId.hpp" #include "NeonWorkloadFactory.hpp" +#include #include +#include #include @@ -40,4 +42,14 @@ IBackendInternal::IWorkloadFactoryPtr NeonBackend::CreateWorkloadFactory() const return std::make_unique(); } +IBackendInternal::IBackendContextPtr NeonBackend::CreateBackendContext(const IRuntime::CreationOptions&) const +{ + return IBackendContextPtr{}; +} + +IBackendInternal::Optimizations NeonBackend::GetOptimizations() const +{ + return Optimizations{}; +} + } // namespace armnn \ No newline at end of file diff --git a/src/backends/neon/NeonBackend.hpp b/src/backends/neon/NeonBackend.hpp index 9ee8b238b3..e0017d92c8 100644 --- a/src/backends/neon/NeonBackend.hpp +++ b/src/backends/neon/NeonBackend.hpp @@ -4,7 +4,6 @@ // #pragma once -#include #include namespace armnn @@ -20,11 +19,8 @@ public: const BackendId& GetId() const override { return GetIdStatic(); } IWorkloadFactoryPtr CreateWorkloadFactory() const override; - - IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override - { - return IBackendContextPtr{}; - } + IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override; + IBackendInternal::Optimizations GetOptimizations() const override; }; } // namespace armnn \ No newline at end of file diff --git a/src/backends/reference/RefBackend.cpp b/src/backends/reference/RefBackend.cpp index b6fb0ff5ec..2f5ec8032c 100644 --- a/src/backends/reference/RefBackend.cpp +++ b/src/backends/reference/RefBackend.cpp @@ -7,7 +7,9 @@ #include "RefBackendId.hpp" #include "RefWorkloadFactory.hpp" +#include #include +#include #include @@ -40,4 +42,14 @@ IBackendInternal::IWorkloadFactoryPtr RefBackend::CreateWorkloadFactory() const return std::make_unique(); } +IBackendInternal::IBackendContextPtr RefBackend::CreateBackendContext(const IRuntime::CreationOptions&) const +{ + return IBackendContextPtr{}; +} + +IBackendInternal::Optimizations RefBackend::GetOptimizations() const +{ + return Optimizations{}; +} + } // namespace armnn \ No newline at end of file diff --git a/src/backends/reference/RefBackend.hpp b/src/backends/reference/RefBackend.hpp index 025a4826b2..be71f356f3 100644 --- a/src/backends/reference/RefBackend.hpp +++ b/src/backends/reference/RefBackend.hpp @@ -4,7 +4,6 @@ // #pragma once -#include #include namespace armnn @@ -20,11 +19,8 @@ public: const BackendId& GetId() const override { return GetIdStatic(); } IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory() const override; - - IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override - { - return IBackendContextPtr{}; - } + IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override; + IBackendInternal::Optimizations GetOptimizations() const override; }; } // namespace armnn \ No newline at end of file -- cgit v1.2.1