aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Beck <david.beck@arm.com>2018-11-09 14:46:40 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2018-11-09 16:44:01 +0000
commit263e34988abe54d79133850182190661bfd977df (patch)
tree0cfcee4d839db25eafdb6f7e2a7e19131d289bfa
parent3ea76d5f0d99794cf5f0b60ef3738d0905f10b2a (diff)
downloadarmnn-263e34988abe54d79133850182190661bfd977df.tar.gz
IVGCVSW-2125 : backends now can return optimizations
Change-Id: Ieec34224b433e1d2f3bbe66632cd6016cac5498c
-rw-r--r--src/armnn/Network.cpp21
-rw-r--r--src/backends/backendsCommon/IBackendInternal.hpp5
-rw-r--r--src/backends/cl/ClBackend.cpp6
-rw-r--r--src/backends/cl/ClBackend.hpp2
-rw-r--r--src/backends/neon/NeonBackend.cpp12
-rw-r--r--src/backends/neon/NeonBackend.hpp8
-rw-r--r--src/backends/reference/RefBackend.cpp12
-rw-r--r--src/backends/reference/RefBackend.hpp8
8 files changed, 61 insertions, 13 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 43782e0982..7b430c3ac5 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -11,6 +11,8 @@
#include <backendsCommon/CpuTensorHandle.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
+#include <backendsCommon/BackendRegistry.hpp>
+#include <backendsCommon/IBackendInternal.hpp>
#include <armnn/Exceptions.hpp>
#include <armnn/Utils.hpp>
@@ -169,6 +171,9 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
return IOptimizedNetworkPtr(nullptr, &IOptimizedNetwork::Destroy);
};
+ // The backends that we choose to run layers on
+ std::unordered_set<BackendId> chosenBackends;
+
// Assign a compute device for all nodes
bool bErrorFound = false;
for (auto&& layer : optNetObjPtr->GetGraph())
@@ -275,6 +280,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
else
{
found = true;
+ chosenBackends.insert(backend);
break;
}
}
@@ -291,6 +297,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
layerType == armnn::LayerType::Permute))
{
layer->SetBackendId(armnn::Compute::CpuRef);
+ chosenBackends.insert(armnn::Compute::CpuRef);
}
else
{
@@ -312,6 +319,20 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
Optimizer::Pass(optNetObjPtr->GetGraph(), MakeOptimizations(ConvertConstantsFloatToHalf()));
Optimizer::Pass(optNetObjPtr->GetGraph(), MakeOptimizations(ConvertConstantsHalfToFloat()));
+ // Run backend specific optimizations
+ for (auto&& chosenBackend : chosenBackends)
+ {
+ auto factoryFun = BackendRegistryInstance().GetFactory(chosenBackend);
+ auto backendPtr = factoryFun();
+ BOOST_ASSERT(backendPtr.get() != nullptr);
+
+ auto backendSpecificOptimizations = backendPtr->GetOptimizations();
+ if (!backendSpecificOptimizations.empty())
+ {
+ Optimizer::Pass(optNetObjPtr->GetGraph(), backendSpecificOptimizations);
+ }
+ }
+
return optNet;
}
diff --git a/src/backends/backendsCommon/IBackendInternal.hpp b/src/backends/backendsCommon/IBackendInternal.hpp
index fede366475..9c54b821e7 100644
--- a/src/backends/backendsCommon/IBackendInternal.hpp
+++ b/src/backends/backendsCommon/IBackendInternal.hpp
@@ -6,11 +6,13 @@
#include <armnn/Types.hpp>
#include <armnn/IRuntime.hpp>
+#include <vector>
namespace armnn
{
class IWorkloadFactory;
class IBackendContext;
+class Optimization;
class IBackendInternal : public IBackend
{
@@ -26,9 +28,12 @@ public:
using IWorkloadFactoryPtr = std::unique_ptr<IWorkloadFactory>;
using IBackendContextPtr = std::unique_ptr<IBackendContext>;
+ using OptimizationPtr = std::unique_ptr<Optimization>;
+ using Optimizations = std::vector<OptimizationPtr>;
virtual IWorkloadFactoryPtr CreateWorkloadFactory() const = 0;
virtual IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const = 0;
+ virtual Optimizations GetOptimizations() const = 0;
};
using IBackendInternalUniquePtr = std::unique_ptr<IBackendInternal>;
diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp
index c07fa66457..8209a109a4 100644
--- a/src/backends/cl/ClBackend.cpp
+++ b/src/backends/cl/ClBackend.cpp
@@ -8,7 +8,9 @@
#include "ClWorkloadFactory.hpp"
#include "ClBackendContext.hpp"
+#include <backendsCommon/IBackendContext.hpp>
#include <backendsCommon/BackendRegistry.hpp>
+#include <Optimizer.hpp>
namespace armnn
{
@@ -45,5 +47,9 @@ ClBackend::CreateBackendContext(const IRuntime::CreationOptions& options) const
return IBackendContextPtr{new ClBackendContext{options}};
}
+IBackendInternal::Optimizations ClBackend::GetOptimizations() const
+{
+ return Optimizations{};
+}
} // namespace armnn
diff --git a/src/backends/cl/ClBackend.hpp b/src/backends/cl/ClBackend.hpp
index f8a6253c22..ad84e8a159 100644
--- a/src/backends/cl/ClBackend.hpp
+++ b/src/backends/cl/ClBackend.hpp
@@ -4,7 +4,6 @@
//
#pragma once
-#include <backendsCommon/IBackendContext.hpp>
#include <backendsCommon/IBackendInternal.hpp>
namespace armnn
@@ -21,6 +20,7 @@ public:
IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory() const override;
IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override;
+ IBackendInternal::Optimizations GetOptimizations() const override;
};
} // namespace armnn \ No newline at end of file
diff --git a/src/backends/neon/NeonBackend.cpp b/src/backends/neon/NeonBackend.cpp
index 7058d24e72..9e079f38ce 100644
--- a/src/backends/neon/NeonBackend.cpp
+++ b/src/backends/neon/NeonBackend.cpp
@@ -7,7 +7,9 @@
#include "NeonBackendId.hpp"
#include "NeonWorkloadFactory.hpp"
+#include <backendsCommon/IBackendContext.hpp>
#include <backendsCommon/BackendRegistry.hpp>
+#include <Optimizer.hpp>
#include <boost/cast.hpp>
@@ -40,4 +42,14 @@ IBackendInternal::IWorkloadFactoryPtr NeonBackend::CreateWorkloadFactory() const
return std::make_unique<NeonWorkloadFactory>();
}
+IBackendInternal::IBackendContextPtr NeonBackend::CreateBackendContext(const IRuntime::CreationOptions&) const
+{
+ return IBackendContextPtr{};
+}
+
+IBackendInternal::Optimizations NeonBackend::GetOptimizations() const
+{
+ return Optimizations{};
+}
+
} // namespace armnn \ No newline at end of file
diff --git a/src/backends/neon/NeonBackend.hpp b/src/backends/neon/NeonBackend.hpp
index 9ee8b238b3..e0017d92c8 100644
--- a/src/backends/neon/NeonBackend.hpp
+++ b/src/backends/neon/NeonBackend.hpp
@@ -4,7 +4,6 @@
//
#pragma once
-#include <backendsCommon/IBackendContext.hpp>
#include <backendsCommon/IBackendInternal.hpp>
namespace armnn
@@ -20,11 +19,8 @@ public:
const BackendId& GetId() const override { return GetIdStatic(); }
IWorkloadFactoryPtr CreateWorkloadFactory() const override;
-
- IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override
- {
- return IBackendContextPtr{};
- }
+ IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override;
+ IBackendInternal::Optimizations GetOptimizations() const override;
};
} // namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/RefBackend.cpp b/src/backends/reference/RefBackend.cpp
index b6fb0ff5ec..2f5ec8032c 100644
--- a/src/backends/reference/RefBackend.cpp
+++ b/src/backends/reference/RefBackend.cpp
@@ -7,7 +7,9 @@
#include "RefBackendId.hpp"
#include "RefWorkloadFactory.hpp"
+#include <backendsCommon/IBackendContext.hpp>
#include <backendsCommon/BackendRegistry.hpp>
+#include <Optimizer.hpp>
#include <boost/cast.hpp>
@@ -40,4 +42,14 @@ IBackendInternal::IWorkloadFactoryPtr RefBackend::CreateWorkloadFactory() const
return std::make_unique<RefWorkloadFactory>();
}
+IBackendInternal::IBackendContextPtr RefBackend::CreateBackendContext(const IRuntime::CreationOptions&) const
+{
+ return IBackendContextPtr{};
+}
+
+IBackendInternal::Optimizations RefBackend::GetOptimizations() const
+{
+ return Optimizations{};
+}
+
} // namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/RefBackend.hpp b/src/backends/reference/RefBackend.hpp
index 025a4826b2..be71f356f3 100644
--- a/src/backends/reference/RefBackend.hpp
+++ b/src/backends/reference/RefBackend.hpp
@@ -4,7 +4,6 @@
//
#pragma once
-#include <backendsCommon/IBackendContext.hpp>
#include <backendsCommon/IBackendInternal.hpp>
namespace armnn
@@ -20,11 +19,8 @@ public:
const BackendId& GetId() const override { return GetIdStatic(); }
IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory() const override;
-
- IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override
- {
- return IBackendContextPtr{};
- }
+ IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override;
+ IBackendInternal::Optimizations GetOptimizations() const override;
};
} // namespace armnn \ No newline at end of file