diff options
author | Derek Lamberti <derek.lamberti@arm.com> | 2019-08-01 15:56:25 +0100 |
---|---|---|
committer | Áron Virginás-Tar <aron.virginas-tar@arm.com> | 2019-08-05 13:51:42 +0000 |
commit | f674aa0fd2809126debdaaeb8067067790d86907 (patch) | |
tree | d86d0261c7a25149217918986043c76d0823ee44 /src/backends/backendsCommon/WorkloadData.cpp | |
parent | 737d9ff58b348b11234b6c2363390607d576177d (diff) | |
download | armnn-f674aa0fd2809126debdaaeb8067067790d86907.tar.gz |
IVGCVSW-3277 Mem export/import suppor for Tensors
* Rename MemoryStrategy to EdgeStrategy
* Add MemImportLayer
* Import memory rather than copy when possible
Change-Id: I1d3a9414f2cbe517dc2aae9bbd4fdd92712b38ef
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index a4d35827fa..1c607da707 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -351,6 +351,109 @@ void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const } } +//--------------------------------------------------------------- +void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1); + + if (workloadInfo.m_InputTensorInfos.size() != 1) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of input infos (%1%) is not 1.") + % workloadInfo.m_InputTensorInfos.size())); + + } + + if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size()) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)") + % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size())); + } + + for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i) + { + if (workloadInfo.m_InputTensorInfos[i].GetNumElements() != + workloadInfo.m_OutputTensorInfos[i].GetNumElements()) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of elements for tensor input and output %1% does not match") + % i )); + } + } + + if (m_Inputs.size() != 1) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of inputs (%1%) is not 1.") + % m_Inputs.size())); + } + + if (m_Inputs.size() != m_Outputs.size()) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)") + % m_Inputs.size() % m_Outputs.size())); + } + + for (unsigned int i = 0; i < m_Inputs.size(); ++i) + { + if (!m_Inputs[i]) + { + throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i)); + } + + if (!m_Outputs[i]) + { + throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i)); + } + } +} + +//--------------------------------------------------------------- +void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1); + + if (workloadInfo.m_InputTensorInfos.size() != 1) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of input infos (%1%) is not 1.") + % workloadInfo.m_InputTensorInfos.size())); + + } + + if (workloadInfo.m_OutputTensorInfos.size() != 0) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of output infos (%1%) is not 0.") + % workloadInfo.m_InputTensorInfos.size())); + + } + + if (m_Inputs.size() != 1) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of inputs (%1%) is not 1.") + % m_Inputs.size())); + } + + if (m_Outputs.size() != 0) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of outputs (%1%) is not 0.") + % m_Inputs.size() % m_Outputs.size())); + } + + if (!m_Inputs[0]) + { + throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0"))); + } +} + +//--------------------------------------------------------------- void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { const std::string descriptorName{"ActivationQueueDescriptor"}; |