aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-06-03 17:10:02 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-06-05 15:06:39 +0000
commit6331f91a4a1cb1ad16c569d98bb9ddf704788464 (patch)
tree338cce081966bfb42f635b6febd68642d492b9f8 /src/backends/backendsCommon/WorkloadData.cpp
parent18f2d1ccf9e743e61ed3733ae5a38f796a759db8 (diff)
downloadarmnn-6331f91a4a1cb1ad16c569d98bb9ddf704788464.tar.gz
IVGCVSW-2971 Support QSymm16 for DetectionPostProcess workloads
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Change-Id: I8af45afe851a9ccbf8bce54727147fcd52ac9a1f
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp66
1 files changed, 38 insertions, 28 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index a373f55d3e..d0aaf1db38 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1459,53 +1459,63 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateNumInputs(workloadInfo, "DetectionPostProcessQueueDescriptor", 2);
+ const std::string& descriptorName = " DetectionPostProcessQueueDescriptor";
+ ValidateNumInputs(workloadInfo, descriptorName, 2);
if (workloadInfo.m_OutputTensorInfos.size() != 4)
{
- throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Requires exactly four outputs. " +
+ throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
}
if (m_Anchors == nullptr)
{
- throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Anchors tensor descriptor is missing.");
+ throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
}
const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
- const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
- const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
- const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
+ const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
+ const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
+
+ const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
- const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
- const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
-
- ValidateTensorNumDimensions(boxEncodingsInfo, "DetectionPostProcessQueueDescriptor", 3, "box encodings");
- ValidateTensorNumDimensions(scoresInfo, "DetectionPostProcessQueueDescriptor", 3, "scores");
- ValidateTensorNumDimensions(anchorsInfo, "DetectionPostProcessQueueDescriptor", 2, "anchors");
-
- ValidateTensorNumDimensions(detectionBoxesInfo, "DetectionPostProcessQueueDescriptor", 3, "detection boxes");
- ValidateTensorNumDimensions(detectionScoresInfo, "DetectionPostProcessQueueDescriptor", 2, "detection scores");
- ValidateTensorNumDimensions(detectionClassesInfo, "DetectionPostProcessQueueDescriptor", 2, "detection classes");
- ValidateTensorNumDimensions(numDetectionsInfo, "DetectionPostProcessQueueDescriptor", 1, "num detections");
-
- ValidateTensorDataType(detectionBoxesInfo, DataType::Float32,
- "DetectionPostProcessQueueDescriptor", "detection boxes");
- ValidateTensorDataType(detectionScoresInfo, DataType::Float32,
- "DetectionPostProcessQueueDescriptor", "detection scores");
- ValidateTensorDataType(detectionClassesInfo, DataType::Float32,
- "DetectionPostProcessQueueDescriptor", "detection classes");
- ValidateTensorDataType(numDetectionsInfo, DataType::Float32,
- "DetectionPostProcessQueueDescriptor", "num detections");
+ const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
+ const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
+
+ ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
+ ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
+ ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
+
+ const std::vector<DataType> supportedInputTypes =
+ {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
+ ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
+ ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
+
+ ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
+ ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
+ ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
+ ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
+
+ // NOTE: Output is always Float32 regardless of input type
+ ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
+ ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
+ ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
+ ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
{
- throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Intersection over union threshold "
+ throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
"must be positive and less than or equal to 1.");
}
if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
{
- throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Number of classes with background "
+ throw InvalidArgumentException(descriptorName + ": Number of classes with background "
"should be equal to number of classes + 1.");
}
}