aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFerran Balaguer <ferran.balaguer@arm.com>2019-08-14 12:11:27 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-08-14 13:09:48 +0000
commit97520101bb98a30315dd6e31198b08ba050b58c0 (patch)
tree7ddde38e3feb3858208063afabbb380722ba66e1
parent4f77ac2685a4130b07a4d28420929d883b251571 (diff)
downloadarmnn-97520101bb98a30315dd6e31198b08ba050b58c0.tar.gz
IVGCVSW-3636 Fix Graph and WorkloaData to support backend Import functionality
Signed-off-by: Ferran Balaguer <ferran.balaguer@arm.com> Change-Id: I634aa3b1d609ca33b196fd68ce7fb7881be73e6e
-rw-r--r--src/armnn/Graph.cpp19
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp16
2 files changed, 13 insertions, 22 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp
index 6212c49eba..f0ef0e18f7 100644
--- a/src/armnn/Graph.cpp
+++ b/src/armnn/Graph.cpp
@@ -344,17 +344,24 @@ void Graph::AddCompatibilityLayers(std::map<BackendId, std::unique_ptr<IBackendI
{
auto srcPref = srcOutputSlot.GetTensorHandleFactoryId();
auto srcFactory = registry.GetFactory(srcPref);
- bool canExportImport = (factory->GetImportFlags() & srcFactory->GetExportFlags()) != 0;
- if (factory->SupportsMapUnmap() || canExportImport)
+ if (srcFactory)
{
- compOutputSlot.SetTensorHandleFactory(preference);
- found = true;
- break;
+ bool canExportImport =
+ (factory->GetImportFlags() & srcFactory->GetExportFlags()) != 0;
+ if (factory->SupportsMapUnmap() || canExportImport)
+ {
+ compOutputSlot.SetTensorHandleFactory(preference);
+ found = true;
+ break;
+ }
}
}
}
- BOOST_ASSERT_MSG(found, "Could not find a valid TensorHandle for compatibilty layer");
+ if (!found)
+ {
+ compOutputSlot.SetTensorHandleFactory(ITensorHandleFactory::LegacyFactoryId);
+ }
}
else
{
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 1c607da707..4b0b84a73d 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -417,22 +417,6 @@ void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
- if (workloadInfo.m_InputTensorInfos.size() != 1)
- {
- throw InvalidArgumentException(boost::str(
- boost::format("Number of input infos (%1%) is not 1.")
- % workloadInfo.m_InputTensorInfos.size()));
-
- }
-
- if (workloadInfo.m_OutputTensorInfos.size() != 0)
- {
- throw InvalidArgumentException(boost::str(
- boost::format("Number of output infos (%1%) is not 0.")
- % workloadInfo.m_InputTensorInfos.size()));
-
- }
-
if (m_Inputs.size() != 1)
{
throw InvalidArgumentException(boost::str(