aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon
diff options
context:
space:
mode:
authorFrancis Murtagh <francis.murtagh@arm.com>2021-04-29 14:23:04 +0100
committerFrancis Murtagh <francis.murtagh@arm.com>2021-04-29 14:55:27 +0000
commit73d3e2e1616ba5dcdb0a190afba2463742bd4fcc (patch)
tree5e03f174a763c275d5874c804996048fb0b505ab /src/backends/backendsCommon
parent4df97eb257d3fc29b7431d9cb8a054b21d5a7448 (diff)
downloadarmnn-73d3e2e1616ba5dcdb0a190afba2463742bd4fcc.tar.gz
IVGCVSW-5819 5820 5821 Add MemorySourceFlags to TensorHandleFactoryRegistry::GetFactory
* Modify Layer::CreateTensorHandles to include MemorySource * Modify INetworkProperties to add MemorySource * Disable Neon/Cl fallback tests until full import implementation complete Change-Id: Ia4fff6ea3d4bf6afca33aae358125ccaec7f9a38 Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r--src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp14
-rw-r--r--src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp4
-rw-r--r--src/backends/backendsCommon/test/EndToEndTestImpl.hpp18
-rw-r--r--src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp4
4 files changed, 31 insertions, 9 deletions
diff --git a/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp b/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp
index 0670461b54..cc8a1361a3 100644
--- a/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp
+++ b/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp
@@ -49,6 +49,20 @@ ITensorHandleFactory* TensorHandleFactoryRegistry::GetFactory(ITensorHandleFacto
return nullptr;
}
+ITensorHandleFactory* TensorHandleFactoryRegistry::GetFactory(ITensorHandleFactory::FactoryId id,
+ MemorySource memSource) const
+{
+ for (auto& factory : m_Factories)
+ {
+ if (factory->GetId() == id && factory->GetImportFlags() == static_cast<MemorySourceFlags>(memSource))
+ {
+ return factory.get();
+ }
+ }
+
+ return nullptr;
+}
+
void TensorHandleFactoryRegistry::AquireMemory()
{
for (auto& mgr : m_MemoryManagers)
diff --git a/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp b/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp
index e9e76e73a6..525db56216 100644
--- a/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp
+++ b/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp
@@ -35,6 +35,10 @@ public:
/// Returns nullptr if not found
ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id) const;
+ /// Overload of above allowing specification of Memory Source
+ ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id,
+ MemorySource memSource) const;
+
/// Aquire memory required for inference
void AquireMemory();
diff --git a/src/backends/backendsCommon/test/EndToEndTestImpl.hpp b/src/backends/backendsCommon/test/EndToEndTestImpl.hpp
index 3a757d0c59..a5fe8c6a62 100644
--- a/src/backends/backendsCommon/test/EndToEndTestImpl.hpp
+++ b/src/backends/backendsCommon/test/EndToEndTestImpl.hpp
@@ -209,7 +209,7 @@ inline void ImportNonAlignedInputPointerTest(std::vector<BackendId> backends)
NetworkId netId;
std::string ignoredErrorMessage;
// Enable Importing
- INetworkProperties networkProperties(true, false);
+ INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Undefined);
runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties);
// Creates structures for input & output
@@ -274,7 +274,7 @@ inline void ExportNonAlignedOutputPointerTest(std::vector<BackendId> backends)
NetworkId netId;
std::string ignoredErrorMessage;
// Enable Importing and Exporting
- INetworkProperties networkProperties(true, true);
+ INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties);
// Creates structures for input & output
@@ -345,7 +345,7 @@ inline void ImportAlignedPointerTest(std::vector<BackendId> backends)
NetworkId netId;
std::string ignoredErrorMessage;
// Enable Importing
- INetworkProperties networkProperties(true, true);
+ INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties);
// Creates structures for input & output
@@ -428,7 +428,9 @@ inline void ImportOnlyWorkload(std::vector<BackendId> backends)
// Load it into the runtime. It should pass.
NetworkId netId;
std::string ignoredErrorMessage;
- INetworkProperties networkProperties(true, false);
+
+ INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Undefined);
+
BOOST_TEST(runtime->LoadNetwork(netId, std::move(optNet),ignoredErrorMessage, networkProperties)
== Status::Success);
@@ -516,7 +518,7 @@ inline void ExportOnlyWorkload(std::vector<BackendId> backends)
// Load it into the runtime. It should pass.
NetworkId netId;
std::string ignoredErrorMessage;
- INetworkProperties networkProperties(false, true);
+ INetworkProperties networkProperties(false, MemorySource::Undefined, MemorySource::Malloc);
BOOST_TEST(runtime->LoadNetwork(netId, std::move(optNet),ignoredErrorMessage, networkProperties)
== Status::Success);
@@ -603,7 +605,9 @@ inline void ImportAndExportWorkload(std::vector<BackendId> backends)
// Load it into the runtime. It should pass.
NetworkId netId;
std::string ignoredErrorMessage;
- INetworkProperties networkProperties(true, true);
+
+ INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
+
BOOST_TEST(runtime->LoadNetwork(netId, std::move(optNet),ignoredErrorMessage, networkProperties)
== Status::Success);
@@ -694,7 +698,7 @@ inline void ExportOutputWithSeveralOutputSlotConnectionsTest(std::vector<Backend
NetworkId netId;
std::string ignoredErrorMessage;
// Enable Importing
- INetworkProperties networkProperties(true, true);
+ INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties);
// Creates structures for input & output
diff --git a/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp b/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp
index 16b10c88ac..b20ff4f142 100644
--- a/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp
+++ b/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp
@@ -42,7 +42,7 @@ void AsyncThreadedEndToEndTestImpl(INetworkPtr network,
// Creates AsyncNetwork
NetworkId networkId = 0;
std::string errorMessage;
- const INetworkProperties networkProperties(false, false, true);
+ const INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined);
runtime->LoadNetwork(networkId, std::move(optNet), errorMessage, networkProperties);
std::vector<InputTensors> inputTensorsVec;
@@ -134,7 +134,7 @@ void AsyncEndToEndTestImpl(INetworkPtr network,
// Creates AsyncNetwork
NetworkId networkId = 0;
std::string errorMessage;
- const INetworkProperties networkProperties(false, false, true);
+ const INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined);
runtime->LoadNetwork(networkId, std::move(optNet), errorMessage, networkProperties);
InputTensors inputTensors;