diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 20e125293a..098e8575b1 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -948,16 +948,32 @@ void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateNumInputs(workloadInfo, "L2NormalizationQueueDescriptor", 1); - ValidateNumOutputs(workloadInfo, "L2NormalizationQueueDescriptor", 1); + const std::string& descriptorName = "L2NormalizationQueueDescriptor"; - ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "L2NormalizationQueueDescriptor", 4, "input"); - ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "L2NormalizationQueueDescriptor", 4, "output"); + ValidateNumInputs(workloadInfo, descriptorName, 1); + ValidateNumOutputs(workloadInfo, descriptorName, 1); + + ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], descriptorName, 4, "input"); + ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], descriptorName, 4, "output"); ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], - "L2NormalizationQueueDescriptor", + descriptorName, "input", "output"); + + // Check the supported data types + std::vector<DataType> supportedTypes = + { + DataType::Float32, + DataType::Float16, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName); + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], supportedTypes, descriptorName); + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + {workloadInfo.m_InputTensorInfos[0].GetDataType()}, descriptorName); } void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const |