From bc67cef3e3dc9e7fe9c4331495009eda48c89527 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Thu, 31 Jan 2019 15:31:54 +0000 Subject: IVGCVSW-2557 Ref Workload Implementation for Detection PostProcess * implementation of DetectionPostProcessQueueDescriptor validate * add Uint8ToFloat32Workload * add implementation of Detection PostProcess functionalities * add ref workload implemenentation for float and uint8 * add layer support for Detection PostProcess in ref * unit tests Change-Id: I650461f49edbb3c533d68ef8700377af51bc3592 --- src/backends/backendsCommon/Workload.hpp | 5 +++ src/backends/backendsCommon/WorkloadData.cpp | 53 ++++++++++++++++++++++++++++ src/backends/backendsCommon/WorkloadData.hpp | 7 ++++ 3 files changed, 65 insertions(+) (limited to 'src/backends/backendsCommon') diff --git a/src/backends/backendsCommon/Workload.hpp b/src/backends/backendsCommon/Workload.hpp index 4d14adbf54..7fb26f8b56 100644 --- a/src/backends/backendsCommon/Workload.hpp +++ b/src/backends/backendsCommon/Workload.hpp @@ -187,4 +187,9 @@ using Float32ToFloat16Workload = MultiTypedWorkload; +template +using Uint8ToFloat32Workload = MultiTypedWorkload; + } //namespace armnn diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 9714b02a80..b31d626550 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1082,6 +1082,59 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateTensorNumDimensions(output, "GatherQueueDescriptor", outputDim, "output"); } +void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateTwoInputs(workloadInfo, "DetectionPostProcessQueueDescriptor"); + + if (workloadInfo.m_OutputTensorInfos.size() != 4) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Requires exactly four outputs. " + + to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided."); + } + + if (m_Anchors == nullptr) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Anchors tensor descriptor is missing."); + } + + const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1]; + const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo(); + const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0]; + const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[1]; + const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[2]; + const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3]; + + ValidateTensorNumDimensions(boxEncodingsInfo, "DetectionPostProcessQueueDescriptor", 3, "box encodings"); + ValidateTensorNumDimensions(scoresInfo, "DetectionPostProcessQueueDescriptor", 3, "scores"); + ValidateTensorNumDimensions(anchorsInfo, "DetectionPostProcessQueueDescriptor", 2, "anchors"); + + ValidateTensorNumDimensions(detectionBoxesInfo, "DetectionPostProcessQueueDescriptor", 3, "detection boxes"); + ValidateTensorNumDimensions(detectionScoresInfo, "DetectionPostProcessQueueDescriptor", 2, "detection scores"); + ValidateTensorNumDimensions(detectionClassesInfo, "DetectionPostProcessQueueDescriptor", 2, "detection classes"); + ValidateTensorNumDimensions(numDetectionsInfo, "DetectionPostProcessQueueDescriptor", 1, "num detections"); + + ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "detection boxes"); + ValidateTensorDataType(detectionScoresInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "detection scores"); + ValidateTensorDataType(detectionClassesInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "detection classes"); + ValidateTensorDataType(numDetectionsInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "num detections"); + + if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Intersection over union threshold " + "must be positive and less than or equal to 1."); + } + if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Number of classes with background " + "should be equal to number of classes + 1."); + } +} + void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { // This is internally generated so it should not need validation. diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index e44eba71af..09f56479cd 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -171,6 +171,13 @@ struct DepthwiseConvolution2dQueueDescriptor : QueueDescriptorWithParameters { + DetectionPostProcessQueueDescriptor() + : m_Anchors(nullptr) + { + } + + const ConstCpuTensorHandle* m_Anchors; + void Validate(const WorkloadInfo& workloadInfo) const; }; -- cgit v1.2.1