From 895339092fa9edc0aa59de0309f79bebacc3fa63 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Thu, 15 Aug 2019 12:08:06 +0100 Subject: IVGCVSW-3545 Update the device specs with the dynamic backend ids * Now the utility function RegisterDynamicBackends returns a list of the backend ids that have been registered * The list of registered ids is added to the list of supported backends in the Runtime * Added unit tests Change-Id: I97bbe1f680920358f5baba5a4666e4983b849cac Signed-off-by: Matteo Martincigh --- src/armnn/DeviceSpec.hpp | 7 ++- src/armnn/Runtime.cpp | 8 ++- .../backendsCommon/DynamicBackendUtils.cpp | 30 ++++++++-- .../backendsCommon/DynamicBackendUtils.hpp | 6 +- .../backendsCommon/test/DynamicBackendTests.hpp | 65 +++++++++++++++++++--- 5 files changed, 95 insertions(+), 21 deletions(-) diff --git a/src/armnn/DeviceSpec.hpp b/src/armnn/DeviceSpec.hpp index 35923e6f9d..32264706fd 100644 --- a/src/armnn/DeviceSpec.hpp +++ b/src/armnn/DeviceSpec.hpp @@ -24,9 +24,14 @@ public: return m_SupportedBackends; } + void AddSupportedBackends(const BackendIdSet& backendIds) + { + m_SupportedBackends.insert(backendIds.begin(), backendIds.end()); + } + private: DeviceSpec() = delete; BackendIdSet m_SupportedBackends; }; -} +} // namespace armnn diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index 6b91863deb..9e874848ec 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -144,8 +144,7 @@ Runtime::Runtime(const CreationOptions& options) { // Store backend contexts for the supported ones const BackendIdSet& supportedBackends = m_DeviceSpec.GetSupportedBackends(); - auto it = supportedBackends.find(id); - if (it != supportedBackends.end()) + if (supportedBackends.find(id) != supportedBackends.end()) { auto factoryFun = BackendRegistryInstance().GetFactory(id); auto backend = factoryFun(); @@ -257,7 +256,10 @@ void Runtime::LoadDynamicBackends(const std::string& overrideBackendPath) m_DynamicBackends = DynamicBackendUtils::CreateDynamicBackends(sharedObjects); // Register the dynamic backends in the backend registry - DynamicBackendUtils::RegisterDynamicBackends(m_DynamicBackends); + BackendIdSet registeredBackendIds = DynamicBackendUtils::RegisterDynamicBackends(m_DynamicBackends); + + // Add the registered dynamic backend ids to the list of supported backends + m_DeviceSpec.AddSupportedBackends(registeredBackendIds); } } // namespace armnn diff --git a/src/backends/backendsCommon/DynamicBackendUtils.cpp b/src/backends/backendsCommon/DynamicBackendUtils.cpp index fadec0c389..fc4336f4ac 100644 --- a/src/backends/backendsCommon/DynamicBackendUtils.cpp +++ b/src/backends/backendsCommon/DynamicBackendUtils.cpp @@ -299,21 +299,25 @@ std::vector DynamicBackendUtils::CreateDynamicBackends(const return dynamicBackends; } -void DynamicBackendUtils::RegisterDynamicBackends(const std::vector& dynamicBackends) +BackendIdSet DynamicBackendUtils::RegisterDynamicBackends(const std::vector& dynamicBackends) { // Get a reference of the backend registry BackendRegistry& backendRegistry = BackendRegistryInstance(); - // Register the dynamic backends in the backend registry - RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends); + // Register the dynamic backends in the backend registry, and return a list of registered backend ids + return RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends); } -void DynamicBackendUtils::RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry, - const std::vector& dynamicBackends) +BackendIdSet DynamicBackendUtils::RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry, + const std::vector& dynamicBackends) { + // Initialize the list of registered backend ids + BackendIdSet registeredBackendIds; + // Register the dynamic backends in the backend registry for (const DynamicBackendPtr& dynamicBackend : dynamicBackends) { + // Get the id of the dynamic backend BackendId dynamicBackendId; try { @@ -362,8 +366,22 @@ void DynamicBackendUtils::RegisterDynamicBackendsImpl(BackendRegistry& backendRe } // Register the dynamic backend - backendRegistry.Register(dynamicBackendId, dynamicBackendFactoryFunction); + try + { + backendRegistry.Register(dynamicBackendId, dynamicBackendFactoryFunction); + } + catch (const InvalidArgumentException& e) + { + BOOST_LOG_TRIVIAL(warning) << "An error has occurred when registering the dynamic backend \"" + << dynamicBackendId << "\": " << e.what(); + continue; + } + + // Add the id of the dynamic backend just registered to the list of registered backend ids + registeredBackendIds.insert(dynamicBackendId); } + + return registeredBackendIds; } } // namespace armnn diff --git a/src/backends/backendsCommon/DynamicBackendUtils.hpp b/src/backends/backendsCommon/DynamicBackendUtils.hpp index 187b0b1eab..0aa0ac8da5 100644 --- a/src/backends/backendsCommon/DynamicBackendUtils.hpp +++ b/src/backends/backendsCommon/DynamicBackendUtils.hpp @@ -39,14 +39,14 @@ public: static std::vector GetSharedObjects(const std::vector& backendPaths); static std::vector CreateDynamicBackends(const std::vector& sharedObjects); - static void RegisterDynamicBackends(const std::vector& dynamicBackends); + static BackendIdSet RegisterDynamicBackends(const std::vector& dynamicBackends); protected: /// Protected methods for testing purposes static bool IsBackendCompatibleImpl(const BackendVersion& backendApiVersion, const BackendVersion& backendVersion); static std::vector GetBackendPathsImpl(const std::string& backendPaths); - static void RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry, - const std::vector& dynamicBackends); + static BackendIdSet RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry, + const std::vector& dynamicBackends); private: static std::string GetDlError(); diff --git a/src/backends/backendsCommon/test/DynamicBackendTests.hpp b/src/backends/backendsCommon/test/DynamicBackendTests.hpp index 74ef6f1ba7..e225124e01 100644 --- a/src/backends/backendsCommon/test/DynamicBackendTests.hpp +++ b/src/backends/backendsCommon/test/DynamicBackendTests.hpp @@ -79,10 +79,11 @@ public: return GetBackendPathsImpl(path); } - static void RegisterDynamicBackendsImplTest(armnn::BackendRegistry& backendRegistry, - const std::vector& dynamicBackends) + static armnn::BackendIdSet RegisterDynamicBackendsImplTest( + armnn::BackendRegistry& backendRegistry, + const std::vector& dynamicBackends) { - RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends); + return RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends); } }; @@ -896,12 +897,15 @@ void RegisterSingleDynamicBackendTestImpl() BackendVersion dynamicBackendVersion = dynamicBackends[0]->GetBackendVersion(); BOOST_TEST(TestDynamicBackendUtils::IsBackendCompatible(dynamicBackendVersion)); - TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends); + BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, + dynamicBackends); BOOST_TEST(backendRegistry.Size() == 1); + BOOST_TEST(registeredBackendIds.size() == 1); BackendIdSet backendIds = backendRegistry.GetBackendIds(); BOOST_TEST(backendIds.size() == 1); BOOST_TEST((backendIds.find(dynamicBackendId) != backendIds.end())); + BOOST_TEST((registeredBackendIds.find(dynamicBackendId) != registeredBackendIds.end())); auto dynamicBackendFactoryFunction = backendRegistry.GetFactory(dynamicBackendId); BOOST_TEST((dynamicBackendFactoryFunction != nullptr)); @@ -960,14 +964,19 @@ void RegisterMultipleDynamicBackendsTestImpl() BackendRegistry backendRegistry; BOOST_TEST(backendRegistry.Size() == 0); - TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends); + BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, + dynamicBackends); BOOST_TEST(backendRegistry.Size() == 3); + BOOST_TEST(registeredBackendIds.size() == 3); BackendIdSet backendIds = backendRegistry.GetBackendIds(); BOOST_TEST(backendIds.size() == 3); BOOST_TEST((backendIds.find(dynamicBackendId1) != backendIds.end())); BOOST_TEST((backendIds.find(dynamicBackendId2) != backendIds.end())); BOOST_TEST((backendIds.find(dynamicBackendId3) != backendIds.end())); + BOOST_TEST((registeredBackendIds.find(dynamicBackendId1) != registeredBackendIds.end())); + BOOST_TEST((registeredBackendIds.find(dynamicBackendId2) != registeredBackendIds.end())); + BOOST_TEST((registeredBackendIds.find(dynamicBackendId3) != registeredBackendIds.end())); for (size_t i = 0; i < dynamicBackends.size(); i++) { @@ -1036,8 +1045,10 @@ void RegisterMultipleInvalidDynamicBackendsTestImpl() BOOST_TEST(backendRegistry.Size() == 0); // Check that no dynamic backend got registered - TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends); + BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, + dynamicBackends); BOOST_TEST(backendRegistry.Size() == 0); + BOOST_TEST(registeredBackendIds.empty()); } void RegisterMixedDynamicBackendsTestImpl() @@ -1165,14 +1176,17 @@ void RegisterMixedDynamicBackendsTestImpl() "TestValid5" }; - TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends); + BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, + dynamicBackends); BOOST_TEST(backendRegistry.Size() == expectedRegisteredbackendIds.size()); + BOOST_TEST(registeredBackendIds.size() == expectedRegisteredbackendIds.size()); BackendIdSet backendIds = backendRegistry.GetBackendIds(); BOOST_TEST(backendIds.size() == expectedRegisteredbackendIds.size()); for (const BackendId& expectedRegisteredbackendId : expectedRegisteredbackendIds) { BOOST_TEST((backendIds.find(expectedRegisteredbackendId) != backendIds.end())); + BOOST_TEST((registeredBackendIds.find(expectedRegisteredbackendId) != registeredBackendIds.end())); auto dynamicBackendFactoryFunction = backendRegistry.GetFactory(expectedRegisteredbackendId); BOOST_TEST((dynamicBackendFactoryFunction != nullptr)); @@ -1190,10 +1204,16 @@ void RuntimeEmptyTestImpl() // Swapping the backend registry storage for testing TestBackendRegistry testBackendRegistry; + const BackendRegistry& backendRegistry = BackendRegistryInstance(); + BOOST_TEST(backendRegistry.Size() == 0); + IRuntime::CreationOptions creationOptions; IRuntimePtr runtime = IRuntime::Create(creationOptions); - const BackendRegistry& backendRegistry = BackendRegistryInstance(); + const DeviceSpec& deviceSpec = *boost::polymorphic_downcast(&runtime->GetDeviceSpec()); + BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends(); + BOOST_TEST(supportedBackendIds.empty()); + BOOST_TEST(backendRegistry.Size() == 0); } @@ -1228,6 +1248,14 @@ void RuntimeDynamicBackendsTestImpl() { BOOST_TEST((backendIds.find(expectedRegisteredbackendId) != backendIds.end())); } + + const DeviceSpec& deviceSpec = *boost::polymorphic_downcast(&runtime->GetDeviceSpec()); + BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends(); + BOOST_TEST(supportedBackendIds.size() == expectedRegisteredbackendIds.size()); + for (const BackendId& expectedRegisteredbackendId : expectedRegisteredbackendIds) + { + BOOST_TEST((supportedBackendIds.find(expectedRegisteredbackendId) != supportedBackendIds.end())); + } } void RuntimeDuplicateDynamicBackendsTestImpl() @@ -1261,6 +1289,14 @@ void RuntimeDuplicateDynamicBackendsTestImpl() { BOOST_TEST((backendIds.find(expectedRegisteredbackendId) != backendIds.end())); } + + const DeviceSpec& deviceSpec = *boost::polymorphic_downcast(&runtime->GetDeviceSpec()); + BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends(); + BOOST_TEST(supportedBackendIds.size() == expectedRegisteredbackendIds.size()); + for (const BackendId& expectedRegisteredbackendId : expectedRegisteredbackendIds) + { + BOOST_TEST((supportedBackendIds.find(expectedRegisteredbackendId) != supportedBackendIds.end())); + } } void RuntimeInvalidDynamicBackendsTestImpl() @@ -1282,6 +1318,10 @@ void RuntimeInvalidDynamicBackendsTestImpl() const BackendRegistry& backendRegistry = BackendRegistryInstance(); BOOST_TEST(backendRegistry.Size() == 0); + + const DeviceSpec& deviceSpec = *boost::polymorphic_downcast(&runtime->GetDeviceSpec()); + BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends(); + BOOST_TEST(supportedBackendIds.empty()); } void RuntimeInvalidOverridePathTestImpl() @@ -1298,6 +1338,10 @@ void RuntimeInvalidOverridePathTestImpl() const BackendRegistry& backendRegistry = BackendRegistryInstance(); BOOST_TEST(backendRegistry.Size() == 0); + + const DeviceSpec& deviceSpec = *boost::polymorphic_downcast(&runtime->GetDeviceSpec()); + BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends(); + BOOST_TEST(supportedBackendIds.empty()); } void CreateReferenceDynamicBackendTestImpl() @@ -1330,6 +1374,11 @@ void CreateReferenceDynamicBackendTestImpl() BackendIdSet backendIds = backendRegistry.GetBackendIds(); BOOST_TEST((backendIds.find("CpuRef") != backendIds.end())); + const DeviceSpec& deviceSpec = *boost::polymorphic_downcast(&runtime->GetDeviceSpec()); + BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends(); + BOOST_TEST(supportedBackendIds.size() == 1); + BOOST_TEST((supportedBackendIds.find("CpuRef") != supportedBackendIds.end())); + // Get the factory function auto referenceDynamicBackendFactoryFunction = backendRegistry.GetFactory("CpuRef"); BOOST_TEST((referenceDynamicBackendFactoryFunction != nullptr)); -- cgit v1.2.1