From f674aa0fd2809126debdaaeb8067067790d86907 Mon Sep 17 00:00:00 2001 From: Derek Lamberti Date: Thu, 1 Aug 2019 15:56:25 +0100 Subject: IVGCVSW-3277 Mem export/import suppor for Tensors * Rename MemoryStrategy to EdgeStrategy * Add MemImportLayer * Import memory rather than copy when possible Change-Id: I1d3a9414f2cbe517dc2aae9bbd4fdd92712b38ef Signed-off-by: Derek Lamberti --- src/backends/backendsCommon/WorkloadData.cpp | 103 +++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) (limited to 'src/backends/backendsCommon/WorkloadData.cpp') diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index a4d35827fa..1c607da707 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -351,6 +351,109 @@ void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const } } +//--------------------------------------------------------------- +void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1); + + if (workloadInfo.m_InputTensorInfos.size() != 1) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of input infos (%1%) is not 1.") + % workloadInfo.m_InputTensorInfos.size())); + + } + + if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size()) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)") + % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size())); + } + + for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i) + { + if (workloadInfo.m_InputTensorInfos[i].GetNumElements() != + workloadInfo.m_OutputTensorInfos[i].GetNumElements()) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of elements for tensor input and output %1% does not match") + % i )); + } + } + + if (m_Inputs.size() != 1) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of inputs (%1%) is not 1.") + % m_Inputs.size())); + } + + if (m_Inputs.size() != m_Outputs.size()) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)") + % m_Inputs.size() % m_Outputs.size())); + } + + for (unsigned int i = 0; i < m_Inputs.size(); ++i) + { + if (!m_Inputs[i]) + { + throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i)); + } + + if (!m_Outputs[i]) + { + throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i)); + } + } +} + +//--------------------------------------------------------------- +void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1); + + if (workloadInfo.m_InputTensorInfos.size() != 1) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of input infos (%1%) is not 1.") + % workloadInfo.m_InputTensorInfos.size())); + + } + + if (workloadInfo.m_OutputTensorInfos.size() != 0) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of output infos (%1%) is not 0.") + % workloadInfo.m_InputTensorInfos.size())); + + } + + if (m_Inputs.size() != 1) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of inputs (%1%) is not 1.") + % m_Inputs.size())); + } + + if (m_Outputs.size() != 0) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of outputs (%1%) is not 0.") + % m_Inputs.size() % m_Outputs.size())); + } + + if (!m_Inputs[0]) + { + throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0"))); + } +} + +//--------------------------------------------------------------- void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { const std::string descriptorName{"ActivationQueueDescriptor"}; -- cgit v1.2.1