diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 53 |
1 files changed, 53 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 9714b02a80..b31d626550 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1082,6 +1082,59 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateTensorNumDimensions(output, "GatherQueueDescriptor", outputDim, "output"); } +void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateTwoInputs(workloadInfo, "DetectionPostProcessQueueDescriptor"); + + if (workloadInfo.m_OutputTensorInfos.size() != 4) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Requires exactly four outputs. " + + to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided."); + } + + if (m_Anchors == nullptr) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Anchors tensor descriptor is missing."); + } + + const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1]; + const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo(); + const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0]; + const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[1]; + const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[2]; + const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3]; + + ValidateTensorNumDimensions(boxEncodingsInfo, "DetectionPostProcessQueueDescriptor", 3, "box encodings"); + ValidateTensorNumDimensions(scoresInfo, "DetectionPostProcessQueueDescriptor", 3, "scores"); + ValidateTensorNumDimensions(anchorsInfo, "DetectionPostProcessQueueDescriptor", 2, "anchors"); + + ValidateTensorNumDimensions(detectionBoxesInfo, "DetectionPostProcessQueueDescriptor", 3, "detection boxes"); + ValidateTensorNumDimensions(detectionScoresInfo, "DetectionPostProcessQueueDescriptor", 2, "detection scores"); + ValidateTensorNumDimensions(detectionClassesInfo, "DetectionPostProcessQueueDescriptor", 2, "detection classes"); + ValidateTensorNumDimensions(numDetectionsInfo, "DetectionPostProcessQueueDescriptor", 1, "num detections"); + + ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "detection boxes"); + ValidateTensorDataType(detectionScoresInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "detection scores"); + ValidateTensorDataType(detectionClassesInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "detection classes"); + ValidateTensorDataType(numDetectionsInfo, DataType::Float32, + "DetectionPostProcessQueueDescriptor", "num detections"); + + if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Intersection over union threshold " + "must be positive and less than or equal to 1."); + } + if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1) + { + throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Number of classes with background " + "should be equal to number of classes + 1."); + } +} + void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { // This is internally generated so it should not need validation. |