diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-06-03 16:54:25 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2019-06-04 15:13:51 +0000 |
commit | 3122bd574a3d29774c535ca2136de361da626e88 (patch) | |
tree | c2fcc19be67f5a35c30d042b80ba3157ef87bd21 /src/backends/backendsCommon/WorkloadData.cpp | |
parent | 550fe36f687e73c78b57ebfeee9f98fd35f40f24 (diff) | |
download | armnn-3122bd574a3d29774c535ca2136de361da626e88.tar.gz |
IVGCVSW-3212 Refactor the Reference BatchNormalization workloads to
handle Float32 and QAsymm8 types
* Removed the type-specific workload implementations
* Added type-independent RefBatchNormalizationWorkload implementation
* Reworked BachNormImpl to use decoders/encoders
* Improved the validation of the BatchNorm queue descriptor
* Fixed unit tests where necessary
Change-Id: Icf3fa1332292d38ec2fa0b1cb984cab78426034b
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 52 |
1 files changed, 36 insertions, 16 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 6d17f3e042..a43619a466 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -684,28 +684,48 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf { ValidateNumInputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1); ValidateNumOutputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1); + + const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; + + std::vector<DataType> supportedTypes = + { + DataType::Float16, + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(input, supportedTypes, "BatchNormalizationQueueDescriptor"); + ValidateDataTypes(output, supportedTypes, "BatchNormalizationQueueDescriptor"); + + ValidateDataTypes(output, { input.GetDataType() }, "BatchNormalizationQueueDescriptor"); + + ValidateTensorQuantizationSpace(input, output, "BatchNormalizationQueueDescriptor", "input", "output"); + ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], "BatchNormalizationQueueDescriptor", "input", "output"); - ValidatePointer(m_Mean, "BatchNormalizationQueueDescriptor", "mean"); - ValidatePointer(m_Variance, "BatchNormalizationQueueDescriptor", "variance"); - ValidatePointer(m_Beta, "BatchNormalizationQueueDescriptor", "beta"); - ValidatePointer(m_Gamma, "BatchNormalizationQueueDescriptor", "gamma"); - - ValidateTensorNumDimensions(m_Mean->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "mean"); - ValidateTensorNumDimensions(m_Variance->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "variance"); - ValidateTensorNumDimensions(m_Beta->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "beta"); - ValidateTensorNumDimensions(m_Gamma->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "gamma"); - - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Variance->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "variance"); - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Beta->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "beta"); - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Gamma->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "gamma"); + ValidatePointer(m_Mean, "BatchNormalizationQueueDescriptor", "mean"); + ValidatePointer(m_Variance, "BatchNormalizationQueueDescriptor", "variance"); + ValidatePointer(m_Beta, "BatchNormalizationQueueDescriptor", "beta"); + ValidatePointer(m_Gamma, "BatchNormalizationQueueDescriptor", "gamma"); + + const TensorInfo& mean = m_Mean->GetTensorInfo(); + const TensorInfo& variance = m_Variance->GetTensorInfo(); + const TensorInfo& beta = m_Beta->GetTensorInfo(); + const TensorInfo& gamma = m_Gamma->GetTensorInfo(); + + ValidateTensorNumDimensions(mean, "BatchNormalizationQueueDescriptor", 1, "mean"); + ValidateTensorNumDimensions(variance, "BatchNormalizationQueueDescriptor", 1, "variance"); + ValidateTensorNumDimensions(beta, "BatchNormalizationQueueDescriptor", 1, "beta"); + ValidateTensorNumDimensions(gamma, "BatchNormalizationQueueDescriptor", 1, "gamma"); + + ValidateTensorShapesMatch(mean, variance, "BatchNormalizationQueueDescriptor", "mean", "variance"); + ValidateTensorShapesMatch(mean, beta, "BatchNormalizationQueueDescriptor", "mean", "beta"); + ValidateTensorShapesMatch(mean, gamma, "BatchNormalizationQueueDescriptor", "mean", "gamma"); } void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const |