diff options
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 52 |
1 files changed, 36 insertions, 16 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 6d17f3e042..a43619a466 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -684,28 +684,48 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf { ValidateNumInputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1); ValidateNumOutputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1); + + const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; + + std::vector<DataType> supportedTypes = + { + DataType::Float16, + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(input, supportedTypes, "BatchNormalizationQueueDescriptor"); + ValidateDataTypes(output, supportedTypes, "BatchNormalizationQueueDescriptor"); + + ValidateDataTypes(output, { input.GetDataType() }, "BatchNormalizationQueueDescriptor"); + + ValidateTensorQuantizationSpace(input, output, "BatchNormalizationQueueDescriptor", "input", "output"); + ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], "BatchNormalizationQueueDescriptor", "input", "output"); - ValidatePointer(m_Mean, "BatchNormalizationQueueDescriptor", "mean"); - ValidatePointer(m_Variance, "BatchNormalizationQueueDescriptor", "variance"); - ValidatePointer(m_Beta, "BatchNormalizationQueueDescriptor", "beta"); - ValidatePointer(m_Gamma, "BatchNormalizationQueueDescriptor", "gamma"); - - ValidateTensorNumDimensions(m_Mean->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "mean"); - ValidateTensorNumDimensions(m_Variance->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "variance"); - ValidateTensorNumDimensions(m_Beta->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "beta"); - ValidateTensorNumDimensions(m_Gamma->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "gamma"); - - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Variance->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "variance"); - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Beta->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "beta"); - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Gamma->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "gamma"); + ValidatePointer(m_Mean, "BatchNormalizationQueueDescriptor", "mean"); + ValidatePointer(m_Variance, "BatchNormalizationQueueDescriptor", "variance"); + ValidatePointer(m_Beta, "BatchNormalizationQueueDescriptor", "beta"); + ValidatePointer(m_Gamma, "BatchNormalizationQueueDescriptor", "gamma"); + + const TensorInfo& mean = m_Mean->GetTensorInfo(); + const TensorInfo& variance = m_Variance->GetTensorInfo(); + const TensorInfo& beta = m_Beta->GetTensorInfo(); + const TensorInfo& gamma = m_Gamma->GetTensorInfo(); + + ValidateTensorNumDimensions(mean, "BatchNormalizationQueueDescriptor", 1, "mean"); + ValidateTensorNumDimensions(variance, "BatchNormalizationQueueDescriptor", 1, "variance"); + ValidateTensorNumDimensions(beta, "BatchNormalizationQueueDescriptor", 1, "beta"); + ValidateTensorNumDimensions(gamma, "BatchNormalizationQueueDescriptor", 1, "gamma"); + + ValidateTensorShapesMatch(mean, variance, "BatchNormalizationQueueDescriptor", "mean", "variance"); + ValidateTensorShapesMatch(mean, beta, "BatchNormalizationQueueDescriptor", "mean", "beta"); + ValidateTensorShapesMatch(mean, gamma, "BatchNormalizationQueueDescriptor", "mean", "gamma"); } void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const |