aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp53
1 files changed, 53 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 9714b02a80..b31d626550 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1082,6 +1082,59 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateTensorNumDimensions(output, "GatherQueueDescriptor", outputDim, "output");
}
+void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ ValidateTwoInputs(workloadInfo, "DetectionPostProcessQueueDescriptor");
+
+ if (workloadInfo.m_OutputTensorInfos.size() != 4)
+ {
+ throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Requires exactly four outputs. " +
+ to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
+ }
+
+ if (m_Anchors == nullptr)
+ {
+ throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Anchors tensor descriptor is missing.");
+ }
+
+ const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
+ const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
+ const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
+ const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[1];
+ const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[2];
+ const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
+
+ ValidateTensorNumDimensions(boxEncodingsInfo, "DetectionPostProcessQueueDescriptor", 3, "box encodings");
+ ValidateTensorNumDimensions(scoresInfo, "DetectionPostProcessQueueDescriptor", 3, "scores");
+ ValidateTensorNumDimensions(anchorsInfo, "DetectionPostProcessQueueDescriptor", 2, "anchors");
+
+ ValidateTensorNumDimensions(detectionBoxesInfo, "DetectionPostProcessQueueDescriptor", 3, "detection boxes");
+ ValidateTensorNumDimensions(detectionScoresInfo, "DetectionPostProcessQueueDescriptor", 2, "detection scores");
+ ValidateTensorNumDimensions(detectionClassesInfo, "DetectionPostProcessQueueDescriptor", 2, "detection classes");
+ ValidateTensorNumDimensions(numDetectionsInfo, "DetectionPostProcessQueueDescriptor", 1, "num detections");
+
+ ValidateTensorDataType(detectionBoxesInfo, DataType::Float32,
+ "DetectionPostProcessQueueDescriptor", "detection boxes");
+ ValidateTensorDataType(detectionScoresInfo, DataType::Float32,
+ "DetectionPostProcessQueueDescriptor", "detection scores");
+ ValidateTensorDataType(detectionClassesInfo, DataType::Float32,
+ "DetectionPostProcessQueueDescriptor", "detection classes");
+ ValidateTensorDataType(numDetectionsInfo, DataType::Float32,
+ "DetectionPostProcessQueueDescriptor", "num detections");
+
+ if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
+ {
+ throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Intersection over union threshold "
+ "must be positive and less than or equal to 1.");
+ }
+ if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
+ {
+ throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Number of classes with background "
+ "should be equal to number of classes + 1.");
+ }
+}
+
void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
// This is internally generated so it should not need validation.