diff options
author | Francis Murtagh <francis.murtagh@arm.com> | 2021-04-29 14:23:04 +0100 |
---|---|---|
committer | Francis Murtagh <francis.murtagh@arm.com> | 2021-04-29 14:55:27 +0000 |
commit | 73d3e2e1616ba5dcdb0a190afba2463742bd4fcc (patch) | |
tree | 5e03f174a763c275d5874c804996048fb0b505ab /src/backends | |
parent | 4df97eb257d3fc29b7431d9cb8a054b21d5a7448 (diff) | |
download | armnn-73d3e2e1616ba5dcdb0a190afba2463742bd4fcc.tar.gz |
IVGCVSW-5819 5820 5821 Add MemorySourceFlags to TensorHandleFactoryRegistry::GetFactory
* Modify Layer::CreateTensorHandles to include MemorySource
* Modify INetworkProperties to add MemorySource
* Disable Neon/Cl fallback tests until full import implementation complete
Change-Id: Ia4fff6ea3d4bf6afca33aae358125ccaec7f9a38
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Diffstat (limited to 'src/backends')
6 files changed, 46 insertions, 25 deletions
diff --git a/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp b/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp index 0670461b54..cc8a1361a3 100644 --- a/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp +++ b/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp @@ -49,6 +49,20 @@ ITensorHandleFactory* TensorHandleFactoryRegistry::GetFactory(ITensorHandleFacto return nullptr; } +ITensorHandleFactory* TensorHandleFactoryRegistry::GetFactory(ITensorHandleFactory::FactoryId id, + MemorySource memSource) const +{ + for (auto& factory : m_Factories) + { + if (factory->GetId() == id && factory->GetImportFlags() == static_cast<MemorySourceFlags>(memSource)) + { + return factory.get(); + } + } + + return nullptr; +} + void TensorHandleFactoryRegistry::AquireMemory() { for (auto& mgr : m_MemoryManagers) diff --git a/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp b/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp index e9e76e73a6..525db56216 100644 --- a/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp +++ b/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp @@ -35,6 +35,10 @@ public: /// Returns nullptr if not found ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id) const; + /// Overload of above allowing specification of Memory Source + ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id, + MemorySource memSource) const; + /// Aquire memory required for inference void AquireMemory(); diff --git a/src/backends/backendsCommon/test/EndToEndTestImpl.hpp b/src/backends/backendsCommon/test/EndToEndTestImpl.hpp index 3a757d0c59..a5fe8c6a62 100644 --- a/src/backends/backendsCommon/test/EndToEndTestImpl.hpp +++ b/src/backends/backendsCommon/test/EndToEndTestImpl.hpp @@ -209,7 +209,7 @@ inline void ImportNonAlignedInputPointerTest(std::vector<BackendId> backends) NetworkId netId; std::string ignoredErrorMessage; // Enable Importing - INetworkProperties networkProperties(true, false); + INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Undefined); runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); // Creates structures for input & output @@ -274,7 +274,7 @@ inline void ExportNonAlignedOutputPointerTest(std::vector<BackendId> backends) NetworkId netId; std::string ignoredErrorMessage; // Enable Importing and Exporting - INetworkProperties networkProperties(true, true); + INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); // Creates structures for input & output @@ -345,7 +345,7 @@ inline void ImportAlignedPointerTest(std::vector<BackendId> backends) NetworkId netId; std::string ignoredErrorMessage; // Enable Importing - INetworkProperties networkProperties(true, true); + INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); // Creates structures for input & output @@ -428,7 +428,9 @@ inline void ImportOnlyWorkload(std::vector<BackendId> backends) // Load it into the runtime. It should pass. NetworkId netId; std::string ignoredErrorMessage; - INetworkProperties networkProperties(true, false); + + INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Undefined); + BOOST_TEST(runtime->LoadNetwork(netId, std::move(optNet),ignoredErrorMessage, networkProperties) == Status::Success); @@ -516,7 +518,7 @@ inline void ExportOnlyWorkload(std::vector<BackendId> backends) // Load it into the runtime. It should pass. NetworkId netId; std::string ignoredErrorMessage; - INetworkProperties networkProperties(false, true); + INetworkProperties networkProperties(false, MemorySource::Undefined, MemorySource::Malloc); BOOST_TEST(runtime->LoadNetwork(netId, std::move(optNet),ignoredErrorMessage, networkProperties) == Status::Success); @@ -603,7 +605,9 @@ inline void ImportAndExportWorkload(std::vector<BackendId> backends) // Load it into the runtime. It should pass. NetworkId netId; std::string ignoredErrorMessage; - INetworkProperties networkProperties(true, true); + + INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); + BOOST_TEST(runtime->LoadNetwork(netId, std::move(optNet),ignoredErrorMessage, networkProperties) == Status::Success); @@ -694,7 +698,7 @@ inline void ExportOutputWithSeveralOutputSlotConnectionsTest(std::vector<Backend NetworkId netId; std::string ignoredErrorMessage; // Enable Importing - INetworkProperties networkProperties(true, true); + INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); // Creates structures for input & output diff --git a/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp b/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp index 16b10c88ac..b20ff4f142 100644 --- a/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp +++ b/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp @@ -42,7 +42,7 @@ void AsyncThreadedEndToEndTestImpl(INetworkPtr network, // Creates AsyncNetwork NetworkId networkId = 0; std::string errorMessage; - const INetworkProperties networkProperties(false, false, true); + const INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined); runtime->LoadNetwork(networkId, std::move(optNet), errorMessage, networkProperties); std::vector<InputTensors> inputTensorsVec; @@ -134,7 +134,7 @@ void AsyncEndToEndTestImpl(INetworkPtr network, // Creates AsyncNetwork NetworkId networkId = 0; std::string errorMessage; - const INetworkProperties networkProperties(false, false, true); + const INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined); runtime->LoadNetwork(networkId, std::move(optNet), errorMessage, networkProperties); InputTensors inputTensors; diff --git a/src/backends/cl/test/ClFallbackTests.cpp b/src/backends/cl/test/ClFallbackTests.cpp index 4384ae5fec..eec3afe447 100644 --- a/src/backends/cl/test/ClFallbackTests.cpp +++ b/src/backends/cl/test/ClFallbackTests.cpp @@ -11,7 +11,7 @@ BOOST_AUTO_TEST_SUITE(ClFallback) -BOOST_AUTO_TEST_CASE(ClImportEnabledFallbackToNeon) +BOOST_AUTO_TEST_CASE(ClImportEnabledFallbackToNeon, * boost::unit_test::disabled()) { using namespace armnn; @@ -78,8 +78,7 @@ BOOST_AUTO_TEST_CASE(ClImportEnabledFallbackToNeon) // Load it into the runtime. It should pass. NetworkId netId; std::string ignoredErrorMessage; - INetworkProperties networkProperties(true, true); - + INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); // Creates structures for input & output @@ -259,7 +258,7 @@ BOOST_AUTO_TEST_CASE(ClImportDisabledFallbackToNeon) BOOST_TEST(outputData == expectedOutput); } -BOOST_AUTO_TEST_CASE(ClImportEnabledFallbackSubgraphToNeon) +BOOST_AUTO_TEST_CASE(ClImportEnabledFallbackSubgraphToNeon, * boost::unit_test::disabled()) { using namespace armnn; @@ -337,8 +336,7 @@ BOOST_AUTO_TEST_CASE(ClImportEnabledFallbackSubgraphToNeon) // Load it into the runtime. It should pass. NetworkId netId; std::string ignoredErrorMessage; - INetworkProperties networkProperties(true, true); - + INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); // Creates structures for input & output diff --git a/src/backends/neon/test/NeonFallbackTests.cpp b/src/backends/neon/test/NeonFallbackTests.cpp index 2d70cc2b1b..8dc592db5d 100644 --- a/src/backends/neon/test/NeonFallbackTests.cpp +++ b/src/backends/neon/test/NeonFallbackTests.cpp @@ -83,8 +83,7 @@ BOOST_AUTO_TEST_CASE(FallbackImportToCpuAcc) // Load it into the runtime. It should pass. NetworkId netId; std::string ignoredErrorMessage; - INetworkProperties networkProperties(true, true); - + INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); // Creates structures for input & output @@ -218,7 +217,7 @@ BOOST_AUTO_TEST_CASE(FallbackPaddingCopyToCpuAcc) // Load it into the runtime. It should pass. NetworkId netId; std::string ignoredErrorMessage; - INetworkProperties networkProperties(true, true); + INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); @@ -350,8 +349,8 @@ BOOST_AUTO_TEST_CASE(FallbackImportFromCpuAcc) // Load it into the runtime. It should pass. NetworkId netId; std::string ignoredErrorMessage; - INetworkProperties networkProperties(true, true); + INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); // Creates structures for input & output @@ -485,7 +484,7 @@ BOOST_AUTO_TEST_CASE(FallbackPaddingCopyFromCpuAcc) // Load it into the runtime. It should pass. NetworkId netId; std::string ignoredErrorMessage; - INetworkProperties networkProperties(true, true); + INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); @@ -615,7 +614,7 @@ BOOST_AUTO_TEST_CASE(FallbackDisableImportFromCpuAcc) // Load it into the runtime. It should pass. NetworkId netId; std::string ignoredErrorMessage; - INetworkProperties networkProperties(false, false); + INetworkProperties networkProperties(false, MemorySource::Undefined, MemorySource::Undefined); runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); @@ -678,7 +677,7 @@ BOOST_AUTO_TEST_CASE(FallbackDisableImportFromCpuAcc) } #if defined(ARMCOMPUTECL_ENABLED) -BOOST_AUTO_TEST_CASE(NeonImportEnabledFallbackToCl) +BOOST_AUTO_TEST_CASE(NeonImportEnabledFallbackToCl, * boost::unit_test::disabled()) { using namespace armnn; @@ -745,7 +744,8 @@ BOOST_AUTO_TEST_CASE(NeonImportEnabledFallbackToCl) // Load it into the runtime. It should pass. NetworkId netId; std::string ignoredErrorMessage; - INetworkProperties networkProperties(true, true); + + INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); @@ -926,7 +926,7 @@ BOOST_AUTO_TEST_CASE(NeonImportDisabledFallbackToCl) BOOST_TEST(outputData == expectedOutput); } -BOOST_AUTO_TEST_CASE(NeonImportEnabledFallbackSubgraphToCl) +BOOST_AUTO_TEST_CASE(NeonImportEnabledFallbackSubgraphToCl, * boost::unit_test::disabled()) { using namespace armnn; @@ -1004,7 +1004,8 @@ BOOST_AUTO_TEST_CASE(NeonImportEnabledFallbackSubgraphToCl) // Load it into the runtime. It should pass. NetworkId netId; std::string ignoredErrorMessage; - INetworkProperties networkProperties(true, true); + + INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); |