diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-06-03 17:10:02 +0100 |
---|---|---|
committer | Áron Virginás-Tar <aron.virginas-tar@arm.com> | 2019-06-05 15:06:39 +0000 |
commit | 6331f91a4a1cb1ad16c569d98bb9ddf704788464 (patch) | |
tree | 338cce081966bfb42f635b6febd68642d492b9f8 /src/backends/backendsCommon/WorkloadData.cpp | |
parent | 18f2d1ccf9e743e61ed3733ae5a38f796a759db8 (diff) | |
download | armnn-6331f91a4a1cb1ad16c569d98bb9ddf704788464.tar.gz |
IVGCVSW-2971 Support QSymm16 for DetectionPostProcess workloads
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I8af45afe851a9ccbf8bce54727147fcd52ac9a1f
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 66 |
1 files changed, 38 insertions, 28 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index a373f55d3e..d0aaf1db38 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1459,53 +1459,63 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateNumInputs(workloadInfo, "DetectionPostProcessQueueDescriptor", 2); + const std::string& descriptorName = " DetectionPostProcessQueueDescriptor"; + ValidateNumInputs(workloadInfo, descriptorName, 2); if (workloadInfo.m_OutputTensorInfos.size() != 4) { - throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Requires exactly four outputs. " + + throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " + to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided."); } if (m_Anchors == nullptr) { - throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Anchors tensor descriptor is missing."); + throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing."); } const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0]; - const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1]; - const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo(); - const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0]; + const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1]; + const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo(); + + const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0]; const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1]; - const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2]; - const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3]; - - ValidateTensorNumDimensions(boxEncodingsInfo, "DetectionPostProcessQueueDescriptor", 3, "box encodings"); - ValidateTensorNumDimensions(scoresInfo, "DetectionPostProcessQueueDescriptor", 3, "scores"); - ValidateTensorNumDimensions(anchorsInfo, "DetectionPostProcessQueueDescriptor", 2, "anchors"); - - ValidateTensorNumDimensions(detectionBoxesInfo, "DetectionPostProcessQueueDescriptor", 3, "detection boxes"); - ValidateTensorNumDimensions(detectionScoresInfo, "DetectionPostProcessQueueDescriptor", 2, "detection scores"); - ValidateTensorNumDimensions(detectionClassesInfo, "DetectionPostProcessQueueDescriptor", 2, "detection classes"); - ValidateTensorNumDimensions(numDetectionsInfo, "DetectionPostProcessQueueDescriptor", 1, "num detections"); - - ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, - "DetectionPostProcessQueueDescriptor", "detection boxes"); - ValidateTensorDataType(detectionScoresInfo, DataType::Float32, - "DetectionPostProcessQueueDescriptor", "detection scores"); - ValidateTensorDataType(detectionClassesInfo, DataType::Float32, - "DetectionPostProcessQueueDescriptor", "detection classes"); - ValidateTensorDataType(numDetectionsInfo, DataType::Float32, - "DetectionPostProcessQueueDescriptor", "num detections"); + const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2]; + const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3]; + + ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings"); + ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores"); + ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors"); + + const std::vector<DataType> supportedInputTypes = + { + DataType::Float32, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; + + ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName); + ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName); + ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName); + + ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes"); + ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores"); + ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes"); + ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections"); + + // NOTE: Output is always Float32 regardless of input type + ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes"); + ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores"); + ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes"); + ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections"); if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f) { - throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Intersection over union threshold " + throw InvalidArgumentException(descriptorName + ": Intersection over union threshold " "must be positive and less than or equal to 1."); } if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1) { - throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Number of classes with background " + throw InvalidArgumentException(descriptorName + ": Number of classes with background " "should be equal to number of classes + 1."); } } |