From 895339092fa9edc0aa59de0309f79bebacc3fa63 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Thu, 15 Aug 2019 12:08:06 +0100 Subject: IVGCVSW-3545 Update the device specs with the dynamic backend ids * Now the utility function RegisterDynamicBackends returns a list of the backend ids that have been registered * The list of registered ids is added to the list of supported backends in the Runtime * Added unit tests Change-Id: I97bbe1f680920358f5baba5a4666e4983b849cac Signed-off-by: Matteo Martincigh --- .../backendsCommon/DynamicBackendUtils.cpp | 30 ++++++++-- .../backendsCommon/DynamicBackendUtils.hpp | 6 +- .../backendsCommon/test/DynamicBackendTests.hpp | 65 +++++++++++++++++++--- 3 files changed, 84 insertions(+), 17 deletions(-) (limited to 'src/backends/backendsCommon') diff --git a/src/backends/backendsCommon/DynamicBackendUtils.cpp b/src/backends/backendsCommon/DynamicBackendUtils.cpp index fadec0c389..fc4336f4ac 100644 --- a/src/backends/backendsCommon/DynamicBackendUtils.cpp +++ b/src/backends/backendsCommon/DynamicBackendUtils.cpp @@ -299,21 +299,25 @@ std::vector DynamicBackendUtils::CreateDynamicBackends(const return dynamicBackends; } -void DynamicBackendUtils::RegisterDynamicBackends(const std::vector& dynamicBackends) +BackendIdSet DynamicBackendUtils::RegisterDynamicBackends(const std::vector& dynamicBackends) { // Get a reference of the backend registry BackendRegistry& backendRegistry = BackendRegistryInstance(); - // Register the dynamic backends in the backend registry - RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends); + // Register the dynamic backends in the backend registry, and return a list of registered backend ids + return RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends); } -void DynamicBackendUtils::RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry, - const std::vector& dynamicBackends) +BackendIdSet DynamicBackendUtils::RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry, + const std::vector& dynamicBackends) { + // Initialize the list of registered backend ids + BackendIdSet registeredBackendIds; + // Register the dynamic backends in the backend registry for (const DynamicBackendPtr& dynamicBackend : dynamicBackends) { + // Get the id of the dynamic backend BackendId dynamicBackendId; try { @@ -362,8 +366,22 @@ void DynamicBackendUtils::RegisterDynamicBackendsImpl(BackendRegistry& backendRe } // Register the dynamic backend - backendRegistry.Register(dynamicBackendId, dynamicBackendFactoryFunction); + try + { + backendRegistry.Register(dynamicBackendId, dynamicBackendFactoryFunction); + } + catch (const InvalidArgumentException& e) + { + BOOST_LOG_TRIVIAL(warning) << "An error has occurred when registering the dynamic backend \"" + << dynamicBackendId << "\": " << e.what(); + continue; + } + + // Add the id of the dynamic backend just registered to the list of registered backend ids + registeredBackendIds.insert(dynamicBackendId); } + + return registeredBackendIds; } } // namespace armnn diff --git a/src/backends/backendsCommon/DynamicBackendUtils.hpp b/src/backends/backendsCommon/DynamicBackendUtils.hpp index 187b0b1eab..0aa0ac8da5 100644 --- a/src/backends/backendsCommon/DynamicBackendUtils.hpp +++ b/src/backends/backendsCommon/DynamicBackendUtils.hpp @@ -39,14 +39,14 @@ public: static std::vector GetSharedObjects(const std::vector& backendPaths); static std::vector CreateDynamicBackends(const std::vector& sharedObjects); - static void RegisterDynamicBackends(const std::vector& dynamicBackends); + static BackendIdSet RegisterDynamicBackends(const std::vector& dynamicBackends); protected: /// Protected methods for testing purposes static bool IsBackendCompatibleImpl(const BackendVersion& backendApiVersion, const BackendVersion& backendVersion); static std::vector GetBackendPathsImpl(const std::string& backendPaths); - static void RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry, - const std::vector& dynamicBackends); + static BackendIdSet RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry, + const std::vector& dynamicBackends); private: static std::string GetDlError(); diff --git a/src/backends/backendsCommon/test/DynamicBackendTests.hpp b/src/backends/backendsCommon/test/DynamicBackendTests.hpp index 74ef6f1ba7..e225124e01 100644 --- a/src/backends/backendsCommon/test/DynamicBackendTests.hpp +++ b/src/backends/backendsCommon/test/DynamicBackendTests.hpp @@ -79,10 +79,11 @@ public: return GetBackendPathsImpl(path); } - static void RegisterDynamicBackendsImplTest(armnn::BackendRegistry& backendRegistry, - const std::vector& dynamicBackends) + static armnn::BackendIdSet RegisterDynamicBackendsImplTest( + armnn::BackendRegistry& backendRegistry, + const std::vector& dynamicBackends) { - RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends); + return RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends); } }; @@ -896,12 +897,15 @@ void RegisterSingleDynamicBackendTestImpl() BackendVersion dynamicBackendVersion = dynamicBackends[0]->GetBackendVersion(); BOOST_TEST(TestDynamicBackendUtils::IsBackendCompatible(dynamicBackendVersion)); - TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends); + BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, + dynamicBackends); BOOST_TEST(backendRegistry.Size() == 1); + BOOST_TEST(registeredBackendIds.size() == 1); BackendIdSet backendIds = backendRegistry.GetBackendIds(); BOOST_TEST(backendIds.size() == 1); BOOST_TEST((backendIds.find(dynamicBackendId) != backendIds.end())); + BOOST_TEST((registeredBackendIds.find(dynamicBackendId) != registeredBackendIds.end())); auto dynamicBackendFactoryFunction = backendRegistry.GetFactory(dynamicBackendId); BOOST_TEST((dynamicBackendFactoryFunction != nullptr)); @@ -960,14 +964,19 @@ void RegisterMultipleDynamicBackendsTestImpl() BackendRegistry backendRegistry; BOOST_TEST(backendRegistry.Size() == 0); - TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends); + BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, + dynamicBackends); BOOST_TEST(backendRegistry.Size() == 3); + BOOST_TEST(registeredBackendIds.size() == 3); BackendIdSet backendIds = backendRegistry.GetBackendIds(); BOOST_TEST(backendIds.size() == 3); BOOST_TEST((backendIds.find(dynamicBackendId1) != backendIds.end())); BOOST_TEST((backendIds.find(dynamicBackendId2) != backendIds.end())); BOOST_TEST((backendIds.find(dynamicBackendId3) != backendIds.end())); + BOOST_TEST((registeredBackendIds.find(dynamicBackendId1) != registeredBackendIds.end())); + BOOST_TEST((registeredBackendIds.find(dynamicBackendId2) != registeredBackendIds.end())); + BOOST_TEST((registeredBackendIds.find(dynamicBackendId3) != registeredBackendIds.end())); for (size_t i = 0; i < dynamicBackends.size(); i++) { @@ -1036,8 +1045,10 @@ void RegisterMultipleInvalidDynamicBackendsTestImpl() BOOST_TEST(backendRegistry.Size() == 0); // Check that no dynamic backend got registered - TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends); + BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, + dynamicBackends); BOOST_TEST(backendRegistry.Size() == 0); + BOOST_TEST(registeredBackendIds.empty()); } void RegisterMixedDynamicBackendsTestImpl() @@ -1165,14 +1176,17 @@ void RegisterMixedDynamicBackendsTestImpl() "TestValid5" }; - TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends); + BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, + dynamicBackends); BOOST_TEST(backendRegistry.Size() == expectedRegisteredbackendIds.size()); + BOOST_TEST(registeredBackendIds.size() == expectedRegisteredbackendIds.size()); BackendIdSet backendIds = backendRegistry.GetBackendIds(); BOOST_TEST(backendIds.size() == expectedRegisteredbackendIds.size()); for (const BackendId& expectedRegisteredbackendId : expectedRegisteredbackendIds) { BOOST_TEST((backendIds.find(expectedRegisteredbackendId) != backendIds.end())); + BOOST_TEST((registeredBackendIds.find(expectedRegisteredbackendId) != registeredBackendIds.end())); auto dynamicBackendFactoryFunction = backendRegistry.GetFactory(expectedRegisteredbackendId); BOOST_TEST((dynamicBackendFactoryFunction != nullptr)); @@ -1190,10 +1204,16 @@ void RuntimeEmptyTestImpl() // Swapping the backend registry storage for testing TestBackendRegistry testBackendRegistry; + const BackendRegistry& backendRegistry = BackendRegistryInstance(); + BOOST_TEST(backendRegistry.Size() == 0); + IRuntime::CreationOptions creationOptions; IRuntimePtr runtime = IRuntime::Create(creationOptions); - const BackendRegistry& backendRegistry = BackendRegistryInstance(); + const DeviceSpec& deviceSpec = *boost::polymorphic_downcast(&runtime->GetDeviceSpec()); + BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends(); + BOOST_TEST(supportedBackendIds.empty()); + BOOST_TEST(backendRegistry.Size() == 0); } @@ -1228,6 +1248,14 @@ void RuntimeDynamicBackendsTestImpl() { BOOST_TEST((backendIds.find(expectedRegisteredbackendId) != backendIds.end())); } + + const DeviceSpec& deviceSpec = *boost::polymorphic_downcast(&runtime->GetDeviceSpec()); + BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends(); + BOOST_TEST(supportedBackendIds.size() == expectedRegisteredbackendIds.size()); + for (const BackendId& expectedRegisteredbackendId : expectedRegisteredbackendIds) + { + BOOST_TEST((supportedBackendIds.find(expectedRegisteredbackendId) != supportedBackendIds.end())); + } } void RuntimeDuplicateDynamicBackendsTestImpl() @@ -1261,6 +1289,14 @@ void RuntimeDuplicateDynamicBackendsTestImpl() { BOOST_TEST((backendIds.find(expectedRegisteredbackendId) != backendIds.end())); } + + const DeviceSpec& deviceSpec = *boost::polymorphic_downcast(&runtime->GetDeviceSpec()); + BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends(); + BOOST_TEST(supportedBackendIds.size() == expectedRegisteredbackendIds.size()); + for (const BackendId& expectedRegisteredbackendId : expectedRegisteredbackendIds) + { + BOOST_TEST((supportedBackendIds.find(expectedRegisteredbackendId) != supportedBackendIds.end())); + } } void RuntimeInvalidDynamicBackendsTestImpl() @@ -1282,6 +1318,10 @@ void RuntimeInvalidDynamicBackendsTestImpl() const BackendRegistry& backendRegistry = BackendRegistryInstance(); BOOST_TEST(backendRegistry.Size() == 0); + + const DeviceSpec& deviceSpec = *boost::polymorphic_downcast(&runtime->GetDeviceSpec()); + BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends(); + BOOST_TEST(supportedBackendIds.empty()); } void RuntimeInvalidOverridePathTestImpl() @@ -1298,6 +1338,10 @@ void RuntimeInvalidOverridePathTestImpl() const BackendRegistry& backendRegistry = BackendRegistryInstance(); BOOST_TEST(backendRegistry.Size() == 0); + + const DeviceSpec& deviceSpec = *boost::polymorphic_downcast(&runtime->GetDeviceSpec()); + BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends(); + BOOST_TEST(supportedBackendIds.empty()); } void CreateReferenceDynamicBackendTestImpl() @@ -1330,6 +1374,11 @@ void CreateReferenceDynamicBackendTestImpl() BackendIdSet backendIds = backendRegistry.GetBackendIds(); BOOST_TEST((backendIds.find("CpuRef") != backendIds.end())); + const DeviceSpec& deviceSpec = *boost::polymorphic_downcast(&runtime->GetDeviceSpec()); + BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends(); + BOOST_TEST(supportedBackendIds.size() == 1); + BOOST_TEST((supportedBackendIds.find("CpuRef") != supportedBackendIds.end())); + // Get the factory function auto referenceDynamicBackendFactoryFunction = backendRegistry.GetFactory("CpuRef"); BOOST_TEST((referenceDynamicBackendFactoryFunction != nullptr)); -- cgit v1.2.1