From 3122bd574a3d29774c535ca2136de361da626e88 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Mon, 3 Jun 2019 16:54:25 +0100 Subject: IVGCVSW-3212 Refactor the Reference BatchNormalization workloads to handle Float32 and QAsymm8 types * Removed the type-specific workload implementations * Added type-independent RefBatchNormalizationWorkload implementation * Reworked BachNormImpl to use decoders/encoders * Improved the validation of the BatchNorm queue descriptor * Fixed unit tests where necessary Change-Id: Icf3fa1332292d38ec2fa0b1cb984cab78426034b Signed-off-by: Matteo Martincigh --- src/backends/backendsCommon/WorkloadData.cpp | 52 +++++++++++++++++++--------- 1 file changed, 36 insertions(+), 16 deletions(-) (limited to 'src/backends/backendsCommon') diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 6d17f3e042..a43619a466 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -684,28 +684,48 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf { ValidateNumInputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1); ValidateNumOutputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1); + + const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; + + std::vector supportedTypes = + { + DataType::Float16, + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(input, supportedTypes, "BatchNormalizationQueueDescriptor"); + ValidateDataTypes(output, supportedTypes, "BatchNormalizationQueueDescriptor"); + + ValidateDataTypes(output, { input.GetDataType() }, "BatchNormalizationQueueDescriptor"); + + ValidateTensorQuantizationSpace(input, output, "BatchNormalizationQueueDescriptor", "input", "output"); + ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], "BatchNormalizationQueueDescriptor", "input", "output"); - ValidatePointer(m_Mean, "BatchNormalizationQueueDescriptor", "mean"); - ValidatePointer(m_Variance, "BatchNormalizationQueueDescriptor", "variance"); - ValidatePointer(m_Beta, "BatchNormalizationQueueDescriptor", "beta"); - ValidatePointer(m_Gamma, "BatchNormalizationQueueDescriptor", "gamma"); - - ValidateTensorNumDimensions(m_Mean->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "mean"); - ValidateTensorNumDimensions(m_Variance->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "variance"); - ValidateTensorNumDimensions(m_Beta->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "beta"); - ValidateTensorNumDimensions(m_Gamma->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "gamma"); - - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Variance->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "variance"); - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Beta->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "beta"); - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Gamma->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "gamma"); + ValidatePointer(m_Mean, "BatchNormalizationQueueDescriptor", "mean"); + ValidatePointer(m_Variance, "BatchNormalizationQueueDescriptor", "variance"); + ValidatePointer(m_Beta, "BatchNormalizationQueueDescriptor", "beta"); + ValidatePointer(m_Gamma, "BatchNormalizationQueueDescriptor", "gamma"); + + const TensorInfo& mean = m_Mean->GetTensorInfo(); + const TensorInfo& variance = m_Variance->GetTensorInfo(); + const TensorInfo& beta = m_Beta->GetTensorInfo(); + const TensorInfo& gamma = m_Gamma->GetTensorInfo(); + + ValidateTensorNumDimensions(mean, "BatchNormalizationQueueDescriptor", 1, "mean"); + ValidateTensorNumDimensions(variance, "BatchNormalizationQueueDescriptor", 1, "variance"); + ValidateTensorNumDimensions(beta, "BatchNormalizationQueueDescriptor", 1, "beta"); + ValidateTensorNumDimensions(gamma, "BatchNormalizationQueueDescriptor", 1, "gamma"); + + ValidateTensorShapesMatch(mean, variance, "BatchNormalizationQueueDescriptor", "mean", "variance"); + ValidateTensorShapesMatch(mean, beta, "BatchNormalizationQueueDescriptor", "mean", "beta"); + ValidateTensorShapesMatch(mean, gamma, "BatchNormalizationQueueDescriptor", "mean", "gamma"); } void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const -- cgit v1.2.1