aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-08-15 12:08:06 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-08-19 12:12:28 +0000
commit895339092fa9edc0aa59de0309f79bebacc3fa63 (patch)
tree6e98c570423ff4ff292ebd109e9c3792bdd1a3b9
parentf800de2140ca55f29bacfa6795df7a28aba3e5ff (diff)
downloadarmnn-895339092fa9edc0aa59de0309f79bebacc3fa63.tar.gz
IVGCVSW-3545 Update the device specs with the dynamic backend ids
* Now the utility function RegisterDynamicBackends returns a list of the backend ids that have been registered * The list of registered ids is added to the list of supported backends in the Runtime * Added unit tests Change-Id: I97bbe1f680920358f5baba5a4666e4983b849cac Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
-rw-r--r--src/armnn/DeviceSpec.hpp7
-rw-r--r--src/armnn/Runtime.cpp8
-rw-r--r--src/backends/backendsCommon/DynamicBackendUtils.cpp30
-rw-r--r--src/backends/backendsCommon/DynamicBackendUtils.hpp6
-rw-r--r--src/backends/backendsCommon/test/DynamicBackendTests.hpp65
5 files changed, 95 insertions, 21 deletions
diff --git a/src/armnn/DeviceSpec.hpp b/src/armnn/DeviceSpec.hpp
index 35923e6f9d..32264706fd 100644
--- a/src/armnn/DeviceSpec.hpp
+++ b/src/armnn/DeviceSpec.hpp
@@ -24,9 +24,14 @@ public:
return m_SupportedBackends;
}
+ void AddSupportedBackends(const BackendIdSet& backendIds)
+ {
+ m_SupportedBackends.insert(backendIds.begin(), backendIds.end());
+ }
+
private:
DeviceSpec() = delete;
BackendIdSet m_SupportedBackends;
};
-}
+} // namespace armnn
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index 6b91863deb..9e874848ec 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -144,8 +144,7 @@ Runtime::Runtime(const CreationOptions& options)
{
// Store backend contexts for the supported ones
const BackendIdSet& supportedBackends = m_DeviceSpec.GetSupportedBackends();
- auto it = supportedBackends.find(id);
- if (it != supportedBackends.end())
+ if (supportedBackends.find(id) != supportedBackends.end())
{
auto factoryFun = BackendRegistryInstance().GetFactory(id);
auto backend = factoryFun();
@@ -257,7 +256,10 @@ void Runtime::LoadDynamicBackends(const std::string& overrideBackendPath)
m_DynamicBackends = DynamicBackendUtils::CreateDynamicBackends(sharedObjects);
// Register the dynamic backends in the backend registry
- DynamicBackendUtils::RegisterDynamicBackends(m_DynamicBackends);
+ BackendIdSet registeredBackendIds = DynamicBackendUtils::RegisterDynamicBackends(m_DynamicBackends);
+
+ // Add the registered dynamic backend ids to the list of supported backends
+ m_DeviceSpec.AddSupportedBackends(registeredBackendIds);
}
} // namespace armnn
diff --git a/src/backends/backendsCommon/DynamicBackendUtils.cpp b/src/backends/backendsCommon/DynamicBackendUtils.cpp
index fadec0c389..fc4336f4ac 100644
--- a/src/backends/backendsCommon/DynamicBackendUtils.cpp
+++ b/src/backends/backendsCommon/DynamicBackendUtils.cpp
@@ -299,21 +299,25 @@ std::vector<DynamicBackendPtr> DynamicBackendUtils::CreateDynamicBackends(const
return dynamicBackends;
}
-void DynamicBackendUtils::RegisterDynamicBackends(const std::vector<DynamicBackendPtr>& dynamicBackends)
+BackendIdSet DynamicBackendUtils::RegisterDynamicBackends(const std::vector<DynamicBackendPtr>& dynamicBackends)
{
// Get a reference of the backend registry
BackendRegistry& backendRegistry = BackendRegistryInstance();
- // Register the dynamic backends in the backend registry
- RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends);
+ // Register the dynamic backends in the backend registry, and return a list of registered backend ids
+ return RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends);
}
-void DynamicBackendUtils::RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry,
- const std::vector<DynamicBackendPtr>& dynamicBackends)
+BackendIdSet DynamicBackendUtils::RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry,
+ const std::vector<DynamicBackendPtr>& dynamicBackends)
{
+ // Initialize the list of registered backend ids
+ BackendIdSet registeredBackendIds;
+
// Register the dynamic backends in the backend registry
for (const DynamicBackendPtr& dynamicBackend : dynamicBackends)
{
+ // Get the id of the dynamic backend
BackendId dynamicBackendId;
try
{
@@ -362,8 +366,22 @@ void DynamicBackendUtils::RegisterDynamicBackendsImpl(BackendRegistry& backendRe
}
// Register the dynamic backend
- backendRegistry.Register(dynamicBackendId, dynamicBackendFactoryFunction);
+ try
+ {
+ backendRegistry.Register(dynamicBackendId, dynamicBackendFactoryFunction);
+ }
+ catch (const InvalidArgumentException& e)
+ {
+ BOOST_LOG_TRIVIAL(warning) << "An error has occurred when registering the dynamic backend \""
+ << dynamicBackendId << "\": " << e.what();
+ continue;
+ }
+
+ // Add the id of the dynamic backend just registered to the list of registered backend ids
+ registeredBackendIds.insert(dynamicBackendId);
}
+
+ return registeredBackendIds;
}
} // namespace armnn
diff --git a/src/backends/backendsCommon/DynamicBackendUtils.hpp b/src/backends/backendsCommon/DynamicBackendUtils.hpp
index 187b0b1eab..0aa0ac8da5 100644
--- a/src/backends/backendsCommon/DynamicBackendUtils.hpp
+++ b/src/backends/backendsCommon/DynamicBackendUtils.hpp
@@ -39,14 +39,14 @@ public:
static std::vector<std::string> GetSharedObjects(const std::vector<std::string>& backendPaths);
static std::vector<DynamicBackendPtr> CreateDynamicBackends(const std::vector<std::string>& sharedObjects);
- static void RegisterDynamicBackends(const std::vector<DynamicBackendPtr>& dynamicBackends);
+ static BackendIdSet RegisterDynamicBackends(const std::vector<DynamicBackendPtr>& dynamicBackends);
protected:
/// Protected methods for testing purposes
static bool IsBackendCompatibleImpl(const BackendVersion& backendApiVersion, const BackendVersion& backendVersion);
static std::vector<std::string> GetBackendPathsImpl(const std::string& backendPaths);
- static void RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry,
- const std::vector<DynamicBackendPtr>& dynamicBackends);
+ static BackendIdSet RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry,
+ const std::vector<DynamicBackendPtr>& dynamicBackends);
private:
static std::string GetDlError();
diff --git a/src/backends/backendsCommon/test/DynamicBackendTests.hpp b/src/backends/backendsCommon/test/DynamicBackendTests.hpp
index 74ef6f1ba7..e225124e01 100644
--- a/src/backends/backendsCommon/test/DynamicBackendTests.hpp
+++ b/src/backends/backendsCommon/test/DynamicBackendTests.hpp
@@ -79,10 +79,11 @@ public:
return GetBackendPathsImpl(path);
}
- static void RegisterDynamicBackendsImplTest(armnn::BackendRegistry& backendRegistry,
- const std::vector<armnn::DynamicBackendPtr>& dynamicBackends)
+ static armnn::BackendIdSet RegisterDynamicBackendsImplTest(
+ armnn::BackendRegistry& backendRegistry,
+ const std::vector<armnn::DynamicBackendPtr>& dynamicBackends)
{
- RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends);
+ return RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends);
}
};
@@ -896,12 +897,15 @@ void RegisterSingleDynamicBackendTestImpl()
BackendVersion dynamicBackendVersion = dynamicBackends[0]->GetBackendVersion();
BOOST_TEST(TestDynamicBackendUtils::IsBackendCompatible(dynamicBackendVersion));
- TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends);
+ BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry,
+ dynamicBackends);
BOOST_TEST(backendRegistry.Size() == 1);
+ BOOST_TEST(registeredBackendIds.size() == 1);
BackendIdSet backendIds = backendRegistry.GetBackendIds();
BOOST_TEST(backendIds.size() == 1);
BOOST_TEST((backendIds.find(dynamicBackendId) != backendIds.end()));
+ BOOST_TEST((registeredBackendIds.find(dynamicBackendId) != registeredBackendIds.end()));
auto dynamicBackendFactoryFunction = backendRegistry.GetFactory(dynamicBackendId);
BOOST_TEST((dynamicBackendFactoryFunction != nullptr));
@@ -960,14 +964,19 @@ void RegisterMultipleDynamicBackendsTestImpl()
BackendRegistry backendRegistry;
BOOST_TEST(backendRegistry.Size() == 0);
- TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends);
+ BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry,
+ dynamicBackends);
BOOST_TEST(backendRegistry.Size() == 3);
+ BOOST_TEST(registeredBackendIds.size() == 3);
BackendIdSet backendIds = backendRegistry.GetBackendIds();
BOOST_TEST(backendIds.size() == 3);
BOOST_TEST((backendIds.find(dynamicBackendId1) != backendIds.end()));
BOOST_TEST((backendIds.find(dynamicBackendId2) != backendIds.end()));
BOOST_TEST((backendIds.find(dynamicBackendId3) != backendIds.end()));
+ BOOST_TEST((registeredBackendIds.find(dynamicBackendId1) != registeredBackendIds.end()));
+ BOOST_TEST((registeredBackendIds.find(dynamicBackendId2) != registeredBackendIds.end()));
+ BOOST_TEST((registeredBackendIds.find(dynamicBackendId3) != registeredBackendIds.end()));
for (size_t i = 0; i < dynamicBackends.size(); i++)
{
@@ -1036,8 +1045,10 @@ void RegisterMultipleInvalidDynamicBackendsTestImpl()
BOOST_TEST(backendRegistry.Size() == 0);
// Check that no dynamic backend got registered
- TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends);
+ BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry,
+ dynamicBackends);
BOOST_TEST(backendRegistry.Size() == 0);
+ BOOST_TEST(registeredBackendIds.empty());
}
void RegisterMixedDynamicBackendsTestImpl()
@@ -1165,14 +1176,17 @@ void RegisterMixedDynamicBackendsTestImpl()
"TestValid5"
};
- TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends);
+ BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry,
+ dynamicBackends);
BOOST_TEST(backendRegistry.Size() == expectedRegisteredbackendIds.size());
+ BOOST_TEST(registeredBackendIds.size() == expectedRegisteredbackendIds.size());
BackendIdSet backendIds = backendRegistry.GetBackendIds();
BOOST_TEST(backendIds.size() == expectedRegisteredbackendIds.size());
for (const BackendId& expectedRegisteredbackendId : expectedRegisteredbackendIds)
{
BOOST_TEST((backendIds.find(expectedRegisteredbackendId) != backendIds.end()));
+ BOOST_TEST((registeredBackendIds.find(expectedRegisteredbackendId) != registeredBackendIds.end()));
auto dynamicBackendFactoryFunction = backendRegistry.GetFactory(expectedRegisteredbackendId);
BOOST_TEST((dynamicBackendFactoryFunction != nullptr));
@@ -1190,10 +1204,16 @@ void RuntimeEmptyTestImpl()
// Swapping the backend registry storage for testing
TestBackendRegistry testBackendRegistry;
+ const BackendRegistry& backendRegistry = BackendRegistryInstance();
+ BOOST_TEST(backendRegistry.Size() == 0);
+
IRuntime::CreationOptions creationOptions;
IRuntimePtr runtime = IRuntime::Create(creationOptions);
- const BackendRegistry& backendRegistry = BackendRegistryInstance();
+ const DeviceSpec& deviceSpec = *boost::polymorphic_downcast<const DeviceSpec*>(&runtime->GetDeviceSpec());
+ BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends();
+ BOOST_TEST(supportedBackendIds.empty());
+
BOOST_TEST(backendRegistry.Size() == 0);
}
@@ -1228,6 +1248,14 @@ void RuntimeDynamicBackendsTestImpl()
{
BOOST_TEST((backendIds.find(expectedRegisteredbackendId) != backendIds.end()));
}
+
+ const DeviceSpec& deviceSpec = *boost::polymorphic_downcast<const DeviceSpec*>(&runtime->GetDeviceSpec());
+ BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends();
+ BOOST_TEST(supportedBackendIds.size() == expectedRegisteredbackendIds.size());
+ for (const BackendId& expectedRegisteredbackendId : expectedRegisteredbackendIds)
+ {
+ BOOST_TEST((supportedBackendIds.find(expectedRegisteredbackendId) != supportedBackendIds.end()));
+ }
}
void RuntimeDuplicateDynamicBackendsTestImpl()
@@ -1261,6 +1289,14 @@ void RuntimeDuplicateDynamicBackendsTestImpl()
{
BOOST_TEST((backendIds.find(expectedRegisteredbackendId) != backendIds.end()));
}
+
+ const DeviceSpec& deviceSpec = *boost::polymorphic_downcast<const DeviceSpec*>(&runtime->GetDeviceSpec());
+ BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends();
+ BOOST_TEST(supportedBackendIds.size() == expectedRegisteredbackendIds.size());
+ for (const BackendId& expectedRegisteredbackendId : expectedRegisteredbackendIds)
+ {
+ BOOST_TEST((supportedBackendIds.find(expectedRegisteredbackendId) != supportedBackendIds.end()));
+ }
}
void RuntimeInvalidDynamicBackendsTestImpl()
@@ -1282,6 +1318,10 @@ void RuntimeInvalidDynamicBackendsTestImpl()
const BackendRegistry& backendRegistry = BackendRegistryInstance();
BOOST_TEST(backendRegistry.Size() == 0);
+
+ const DeviceSpec& deviceSpec = *boost::polymorphic_downcast<const DeviceSpec*>(&runtime->GetDeviceSpec());
+ BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends();
+ BOOST_TEST(supportedBackendIds.empty());
}
void RuntimeInvalidOverridePathTestImpl()
@@ -1298,6 +1338,10 @@ void RuntimeInvalidOverridePathTestImpl()
const BackendRegistry& backendRegistry = BackendRegistryInstance();
BOOST_TEST(backendRegistry.Size() == 0);
+
+ const DeviceSpec& deviceSpec = *boost::polymorphic_downcast<const DeviceSpec*>(&runtime->GetDeviceSpec());
+ BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends();
+ BOOST_TEST(supportedBackendIds.empty());
}
void CreateReferenceDynamicBackendTestImpl()
@@ -1330,6 +1374,11 @@ void CreateReferenceDynamicBackendTestImpl()
BackendIdSet backendIds = backendRegistry.GetBackendIds();
BOOST_TEST((backendIds.find("CpuRef") != backendIds.end()));
+ const DeviceSpec& deviceSpec = *boost::polymorphic_downcast<const DeviceSpec*>(&runtime->GetDeviceSpec());
+ BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends();
+ BOOST_TEST(supportedBackendIds.size() == 1);
+ BOOST_TEST((supportedBackendIds.find("CpuRef") != supportedBackendIds.end()));
+
// Get the factory function
auto referenceDynamicBackendFactoryFunction = backendRegistry.GetFactory("CpuRef");
BOOST_TEST((referenceDynamicBackendFactoryFunction != nullptr));