diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2019-01-31 15:31:54 +0000 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2019-02-04 10:57:48 +0000 |
commit | bc67cef3e3dc9e7fe9c4331495009eda48c89527 (patch) | |
tree | 6a15af84fbc5989d25213790554acbb46cda5165 /src/backends/backendsCommon | |
parent | c981df3bb24df1f98c233d885e73a2ea5c6d3449 (diff) | |
download | armnn-bc67cef3e3dc9e7fe9c4331495009eda48c89527.tar.gz |
IVGCVSW-2557 Ref Workload Implementation for Detection PostProcess
* implementation of DetectionPostProcessQueueDescriptor validate
* add Uint8ToFloat32Workload
* add implementation of Detection PostProcess functionalities
* add ref workload implemenentation for float and uint8
* add layer support for Detection PostProcess in ref
* unit tests
Change-Id: I650461f49edbb3c533d68ef8700377af51bc3592
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r-- | src/backends/backendsCommon/Workload.hpp | 5 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 53 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.hpp | 7 |
3 files changed, 65 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/Workload.hpp b/src/backends/backendsCommon/Workload.hpp index 4d14adbf54..7fb26f8b56 100644 --- a/src/backends/backendsCommon/Workload.hpp +++ b/src/backends/backendsCommon/Workload.hpp @@ -187,4 +187,9 @@ using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor, armnn::DataType::Float32, armnn::DataType::Float16>; +template <typename QueueDescriptor> +using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor, + armnn::DataType::QuantisedAsymm8, + armnn::DataType::Float32>; + } //namespace armnn diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 9714b02a80..b31d626550 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1082,6 +1082,59 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateTensorNumDimensions(output, "GatherQueueDescriptor", outputDim, "output"); } +void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateTwoInputs(workloadInfo, "DetectionPostProcessQueueDescriptor"); + + if (workloadInfo.m_OutputTensorInfos.size() != 4) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Requires exactly four outputs. " + + to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided."); + } + + if (m_Anchors == nullptr) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Anchors tensor descriptor is missing."); + } + + const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1]; + const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo(); + const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0]; + const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[1]; + const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[2]; + const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3]; + + ValidateTensorNumDimensions(boxEncodingsInfo, "DetectionPostProcessQueueDescriptor", 3, "box encodings"); + ValidateTensorNumDimensions(scoresInfo, "DetectionPostProcessQueueDescriptor", 3, "scores"); + ValidateTensorNumDimensions(anchorsInfo, "DetectionPostProcessQueueDescriptor", 2, "anchors"); + + ValidateTensorNumDimensions(detectionBoxesInfo, "DetectionPostProcessQueueDescriptor", 3, "detection boxes"); + ValidateTensorNumDimensions(detectionScoresInfo, "DetectionPostProcessQueueDescriptor", 2, "detection scores"); + ValidateTensorNumDimensions(detectionClassesInfo, "DetectionPostProcessQueueDescriptor", 2, "detection classes"); + ValidateTensorNumDimensions(numDetectionsInfo, "DetectionPostProcessQueueDescriptor", 1, "num detections"); + + ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "detection boxes"); + ValidateTensorDataType(detectionScoresInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "detection scores"); + ValidateTensorDataType(detectionClassesInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "detection classes"); + ValidateTensorDataType(numDetectionsInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "num detections"); + + if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Intersection over union threshold " + "must be positive and less than or equal to 1."); + } + if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Number of classes with background " + "should be equal to number of classes + 1."); + } +} + void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { // This is internally generated so it should not need validation. diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index e44eba71af..09f56479cd 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -171,6 +171,13 @@ struct DepthwiseConvolution2dQueueDescriptor : QueueDescriptorWithParameters<Dep struct DetectionPostProcessQueueDescriptor : QueueDescriptorWithParameters<DetectionPostProcessDescriptor> { + DetectionPostProcessQueueDescriptor() + : m_Anchors(nullptr) + { + } + + const ConstCpuTensorHandle* m_Anchors; + void Validate(const WorkloadInfo& workloadInfo) const; }; |