diff options
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r-- | src/backends/backendsCommon/Workload.hpp | 5 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 53 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.hpp | 7 |
3 files changed, 65 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/Workload.hpp b/src/backends/backendsCommon/Workload.hpp index 4d14adbf54..7fb26f8b56 100644 --- a/src/backends/backendsCommon/Workload.hpp +++ b/src/backends/backendsCommon/Workload.hpp @@ -187,4 +187,9 @@ using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor, armnn::DataType::Float32, armnn::DataType::Float16>; +template <typename QueueDescriptor> +using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor, + armnn::DataType::QuantisedAsymm8, + armnn::DataType::Float32>; + } //namespace armnn diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 9714b02a80..b31d626550 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1082,6 +1082,59 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateTensorNumDimensions(output, "GatherQueueDescriptor", outputDim, "output"); } +void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateTwoInputs(workloadInfo, "DetectionPostProcessQueueDescriptor"); + + if (workloadInfo.m_OutputTensorInfos.size() != 4) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Requires exactly four outputs. " + + to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided."); + } + + if (m_Anchors == nullptr) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Anchors tensor descriptor is missing."); + } + + const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1]; + const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo(); + const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0]; + const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[1]; + const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[2]; + const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3]; + + ValidateTensorNumDimensions(boxEncodingsInfo, "DetectionPostProcessQueueDescriptor", 3, "box encodings"); + ValidateTensorNumDimensions(scoresInfo, "DetectionPostProcessQueueDescriptor", 3, "scores"); + ValidateTensorNumDimensions(anchorsInfo, "DetectionPostProcessQueueDescriptor", 2, "anchors"); + + ValidateTensorNumDimensions(detectionBoxesInfo, "DetectionPostProcessQueueDescriptor", 3, "detection boxes"); + ValidateTensorNumDimensions(detectionScoresInfo, "DetectionPostProcessQueueDescriptor", 2, "detection scores"); + ValidateTensorNumDimensions(detectionClassesInfo, "DetectionPostProcessQueueDescriptor", 2, "detection classes"); + ValidateTensorNumDimensions(numDetectionsInfo, "DetectionPostProcessQueueDescriptor", 1, "num detections"); + + ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "detection boxes"); + ValidateTensorDataType(detectionScoresInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "detection scores"); + ValidateTensorDataType(detectionClassesInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "detection classes"); + ValidateTensorDataType(numDetectionsInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "num detections"); + + if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Intersection over union threshold " + "must be positive and less than or equal to 1."); + } + if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Number of classes with background " + "should be equal to number of classes + 1."); + } +} + void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { // This is internally generated so it should not need validation. diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index e44eba71af..09f56479cd 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -171,6 +171,13 @@ struct DepthwiseConvolution2dQueueDescriptor : QueueDescriptorWithParameters<Dep struct DetectionPostProcessQueueDescriptor : QueueDescriptorWithParameters<DetectionPostProcessDescriptor> { + DetectionPostProcessQueueDescriptor() + : m_Anchors(nullptr) + { + } + + const ConstCpuTensorHandle* m_Anchors; + void Validate(const WorkloadInfo& workloadInfo) const; }; |