aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/Graph.cpp5
-rw-r--r--src/armnn/LoadedNetwork.cpp37
-rw-r--r--src/armnn/test/TensorHandleStrategyTest.cpp4
3 files changed, 25 insertions, 21 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp
index f0ef0e18f7..5e2acd55a3 100644
--- a/src/armnn/Graph.cpp
+++ b/src/armnn/Graph.cpp
@@ -335,7 +335,6 @@ void Graph::AddCompatibilityLayers(std::map<BackendId, std::unique_ptr<IBackendI
auto backend = backendIt->second.get();
auto tensorHandleFactoryIds = backend->GetHandleFactoryPreferences();
bool found = false;
- boost::ignore_unused(found);
for (auto preference : tensorHandleFactoryIds)
{
@@ -344,10 +343,12 @@ void Graph::AddCompatibilityLayers(std::map<BackendId, std::unique_ptr<IBackendI
{
auto srcPref = srcOutputSlot.GetTensorHandleFactoryId();
auto srcFactory = registry.GetFactory(srcPref);
+
if (srcFactory)
{
bool canExportImport =
- (factory->GetImportFlags() & srcFactory->GetExportFlags()) != 0;
+ (factory->GetImportFlags() & srcFactory->GetExportFlags()) != 0;
+
if (factory->SupportsMapUnmap() || canExportImport)
{
compOutputSlot.SetTensorHandleFactory(preference);
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index f5f79f3940..5b64085869 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -444,26 +444,29 @@ void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* ten
// b) The tensor has zero padding
// c) There is only one connection to the OutputSlot and it is to an OutputLayer.
// d) The output pointer is allocated via malloc. (Other types will be supported in a later release)
- if (layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetNumConnections() == 1)
+ if (layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer().GetType() != LayerType::Input)
{
- MemorySourceFlags importFlags = inputTensorHandle->GetImportFlags();
- if (CheckFlag(importFlags, MemorySource::Malloc))
+ if (layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetNumConnections() == 1)
{
- void* mem = tensorHandle->Map(false);
- bool importOk = inputTensorHandle->Import(mem, MemorySource::Malloc);
- tensorHandle->Unmap();
-
- if (importOk)
+ MemorySourceFlags importFlags = inputTensorHandle->GetImportFlags();
+ if (CheckFlag(importFlags, MemorySource::Malloc))
{
- // Insert synchronization workload
- MemSyncQueueDescriptor syncDesc;
- syncDesc.m_Inputs.push_back(inputTensorHandle);
- info.m_InputTensorInfos.push_back(inputTensorInfo);
- auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info);
- BOOST_ASSERT_MSG(syncWorkload, "No sync workload created");
- m_OutputQueue.push_back(move(syncWorkload));
-
- return; //No need to add the output workload below
+ void *mem = tensorHandle->Map(false);
+ bool importOk = inputTensorHandle->Import(mem, MemorySource::Malloc);
+ tensorHandle->Unmap();
+
+ if (importOk)
+ {
+ // Insert synchronization workload
+ MemSyncQueueDescriptor syncDesc;
+ syncDesc.m_Inputs.push_back(inputTensorHandle);
+ info.m_InputTensorInfos.push_back(inputTensorInfo);
+ auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info);
+ BOOST_ASSERT_MSG(syncWorkload, "No sync workload created");
+ m_OutputQueue.push_back(move(syncWorkload));
+
+ return; //No need to add the output workload below
+ }
}
}
}
diff --git a/src/armnn/test/TensorHandleStrategyTest.cpp b/src/armnn/test/TensorHandleStrategyTest.cpp
index 2056b6fca1..3c53b13e1a 100644
--- a/src/armnn/test/TensorHandleStrategyTest.cpp
+++ b/src/armnn/test/TensorHandleStrategyTest.cpp
@@ -56,7 +56,7 @@ public:
return nullptr;
}
- const FactoryId GetId() const override { return m_Id; }
+ const FactoryId& GetId() const override { return m_Id; }
bool SupportsSubTensors() const override { return true; }
@@ -94,7 +94,7 @@ public:
return nullptr;
}
- const FactoryId GetId() const override { return m_Id; }
+ const FactoryId& GetId() const override { return m_Id; }
bool SupportsSubTensors() const override { return true; }