From 3122bd574a3d29774c535ca2136de361da626e88 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Mon, 3 Jun 2019 16:54:25 +0100 Subject: IVGCVSW-3212 Refactor the Reference BatchNormalization workloads to handle Float32 and QAsymm8 types * Removed the type-specific workload implementations * Added type-independent RefBatchNormalizationWorkload implementation * Reworked BachNormImpl to use decoders/encoders * Improved the validation of the BatchNorm queue descriptor * Fixed unit tests where necessary Change-Id: Icf3fa1332292d38ec2fa0b1cb984cab78426034b Signed-off-by: Matteo Martincigh --- src/backends/reference/workloads/BatchNormImpl.cpp | 82 ++++++++++++++++++++++ src/backends/reference/workloads/BatchNormImpl.hpp | 62 +++------------- src/backends/reference/workloads/CMakeLists.txt | 7 +- src/backends/reference/workloads/Encoders.hpp | 2 + .../RefBatchNormalizationFloat32Workload.cpp | 38 ---------- .../RefBatchNormalizationFloat32Workload.hpp | 28 -------- .../RefBatchNormalizationUint8Workload.cpp | 47 ------------- .../RefBatchNormalizationUint8Workload.hpp | 28 -------- .../workloads/RefBatchNormalizationWorkload.cpp | 45 ++++++++++++ .../workloads/RefBatchNormalizationWorkload.hpp | 28 ++++++++ src/backends/reference/workloads/RefWorkloads.hpp | 3 +- 11 files changed, 172 insertions(+), 198 deletions(-) create mode 100644 src/backends/reference/workloads/BatchNormImpl.cpp delete mode 100644 src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.cpp delete mode 100644 src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp delete mode 100644 src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp delete mode 100644 src/backends/reference/workloads/RefBatchNormalizationUint8Workload.hpp create mode 100644 src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp create mode 100644 src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp (limited to 'src/backends/reference/workloads') diff --git a/src/backends/reference/workloads/BatchNormImpl.cpp b/src/backends/reference/workloads/BatchNormImpl.cpp new file mode 100644 index 0000000000..36e96d3fec --- /dev/null +++ b/src/backends/reference/workloads/BatchNormImpl.cpp @@ -0,0 +1,82 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "BatchNormImpl.hpp" +#include "RefWorkloadUtils.hpp" + +#include + +#include + +#include + +namespace armnn +{ + +void BatchNormImpl(const BatchNormalizationQueueDescriptor& data, + Decoder& meanDecoder, + Decoder& varianceDecoder, + Decoder& betaDecoder, + Decoder& gammaDecoder, + Decoder& inputDecoder, + Encoder& outputEncoder) +{ + const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]); + const TensorShape inputShape = inputInfo.GetShape(); + + armnnUtils::DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout); + + unsigned int inputBatches = inputShape[0]; + unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()]; + unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()]; + unsigned int inputChannels = inputShape[dataLayout.GetChannelsIndex()]; + + for (unsigned int c = 0; c < inputChannels; c++) + { + meanDecoder[c]; + varianceDecoder[c]; + betaDecoder[c]; + gammaDecoder[c]; + float mean = meanDecoder.Get(); + float var = varianceDecoder.Get(); + float beta = betaDecoder.Get(); + float gamma = gammaDecoder.Get(); + + float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps); + float add = beta - mult * mean; + + for (unsigned int n = 0; n < inputBatches; n++) + { + for (unsigned int h = 0; h < inputHeight; h++) + { + for (unsigned int w = 0; w < inputWidth; w++) + { + unsigned int index = 0; + + if (dataLayout == DataLayout::NHWC) + { + index = n * inputHeight * inputWidth * inputChannels + + h * inputWidth * inputChannels + + w * inputChannels + + c; + } + else // dataLayout == DataLayout::NCHW + { + index = n * inputHeight * inputWidth * inputChannels + + c * inputHeight * inputWidth + + h * inputWidth + + w; + } + + inputDecoder[index]; + outputEncoder[index]; + outputEncoder.Set(mult * inputDecoder.Get() + add); + } + } + } + } +} + +} // namespace armnn diff --git a/src/backends/reference/workloads/BatchNormImpl.hpp b/src/backends/reference/workloads/BatchNormImpl.hpp index 799e7a327b..c0250b9e0f 100644 --- a/src/backends/reference/workloads/BatchNormImpl.hpp +++ b/src/backends/reference/workloads/BatchNormImpl.hpp @@ -5,60 +5,20 @@ #pragma once -#include "RefWorkloadUtils.hpp" -#include "TensorBufferArrayView.hpp" +#include "Encoders.hpp" +#include "Decoders.hpp" -#include - -#include - -#include +#include namespace armnn { -template -static void BatchNormImpl(NormData data, - const float* varIn, - const float* meanIn, - const float* gammaIn, - const float* betaIn, - float* outputData, - const float* inputData) -{ - const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]); - const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]); - - TensorBufferArrayView input(inputInfo.GetShape(), - inputData, - data.m_Parameters.m_DataLayout); - TensorBufferArrayView output(outputInfo.GetShape(), - outputData, - data.m_Parameters.m_DataLayout); - - armnnUtils::DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout); - - for (unsigned int c = 0; c < inputInfo.GetShape()[dataLayout.GetChannelsIndex()]; c++) - { - float var = varIn[c]; - float mean = meanIn[c]; - float gamma = gammaIn[c]; - float beta = betaIn[c]; - - float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps); - float add = beta - mult * mean; - - for (unsigned int n = 0; n < inputInfo.GetShape()[0]; n++) - { - for (unsigned int h = 0; h < inputInfo.GetShape()[dataLayout.GetHeightIndex()]; h++) - { - for (unsigned int w = 0; w < inputInfo.GetShape()[dataLayout.GetWidthIndex()]; w++) - { - output.Get(n, c, h, w) = mult * input.Get(n, c, h, w) + add; - } - } - } - } -} +void BatchNormImpl(const BatchNormalizationQueueDescriptor& data, + Decoder& meanIn, + Decoder& varIn, + Decoder& betaIn, + Decoder& gammaIn, + Decoder& inputData, + Encoder& outputData); -} //namespace armnn +} // namespace armnn diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index df126c4308..cdca22da31 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -7,6 +7,7 @@ list(APPEND armnnRefBackendWorkloads_sources Activation.cpp Activation.hpp BaseIterator.hpp + BatchNormImpl.cpp BatchNormImpl.hpp BatchToSpaceNd.cpp BatchToSpaceNd.hpp @@ -37,10 +38,8 @@ list(APPEND armnnRefBackendWorkloads_sources Pooling2d.hpp RefActivationWorkload.cpp RefActivationWorkload.hpp - RefBatchNormalizationFloat32Workload.cpp - RefBatchNormalizationFloat32Workload.hpp - RefBatchNormalizationUint8Workload.cpp - RefBatchNormalizationUint8Workload.hpp + RefBatchNormalizationWorkload.cpp + RefBatchNormalizationWorkload.hpp RefBatchToSpaceNdFloat32Workload.cpp RefBatchToSpaceNdFloat32Workload.hpp RefBatchToSpaceNdUint8Workload.cpp diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp index 547bead98a..af3b937c2a 100644 --- a/src/backends/reference/workloads/Encoders.hpp +++ b/src/backends/reference/workloads/Encoders.hpp @@ -7,6 +7,8 @@ #include "BaseIterator.hpp" +#include + namespace armnn { diff --git a/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.cpp b/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.cpp deleted file mode 100644 index 313af9c438..0000000000 --- a/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.cpp +++ /dev/null @@ -1,38 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefBatchNormalizationFloat32Workload.hpp" - -#include "BatchNormImpl.hpp" -#include "RefWorkloadUtils.hpp" - -#include "Profiling.hpp" - -namespace armnn -{ -RefBatchNormalizationFloat32Workload::RefBatchNormalizationFloat32Workload( - const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) - : Float32Workload(descriptor, info), - m_Mean(std::make_unique(*(descriptor.m_Mean))), - m_Variance(std::make_unique(*(descriptor.m_Variance))), - m_Beta(std::make_unique(*(descriptor.m_Beta))), - m_Gamma(std::make_unique(*(descriptor.m_Gamma))) {} - -void RefBatchNormalizationFloat32Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchNormalizationFloat32Workload_Execute"); - - const float* var = m_Variance->GetConstTensor(); - const float* mean = m_Mean->GetConstTensor(); - const float* gamma = m_Gamma->GetConstTensor(); - const float* beta = m_Beta->GetConstTensor(); - - auto inputData = GetInputTensorDataFloat(0, m_Data); - auto outputData = GetOutputTensorDataFloat(0, m_Data); - - BatchNormImpl(m_Data, var, mean, gamma, beta, outputData, inputData); -} - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp b/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp deleted file mode 100644 index 9f92899f4f..0000000000 --- a/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp +++ /dev/null @@ -1,28 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include -#include - -namespace armnn -{ - -class RefBatchNormalizationFloat32Workload : public Float32Workload -{ -public: - explicit RefBatchNormalizationFloat32Workload(const BatchNormalizationQueueDescriptor& descriptor, - const WorkloadInfo& info); - virtual void Execute() const override; - -private: - std::unique_ptr m_Mean; - std::unique_ptr m_Variance; - std::unique_ptr m_Beta; - std::unique_ptr m_Gamma; -}; - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp b/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp deleted file mode 100644 index e248ad4b9d..0000000000 --- a/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp +++ /dev/null @@ -1,47 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefBatchNormalizationUint8Workload.hpp" - -#include "BatchNormImpl.hpp" -#include "RefWorkloadUtils.hpp" - -#include "Profiling.hpp" - -#include - -namespace armnn -{ -RefBatchNormalizationUint8Workload::RefBatchNormalizationUint8Workload( - const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) - : Uint8Workload(descriptor, info), - m_Mean(std::make_unique(*(descriptor.m_Mean))), - m_Variance(std::make_unique(*(descriptor.m_Variance))), - m_Beta(std::make_unique(*(descriptor.m_Beta))), - m_Gamma(std::make_unique(*(descriptor.m_Gamma))) {} - -void RefBatchNormalizationUint8Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchNormalizationUint8Workload_Execute"); - - const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorInfo& varInfo = GetTensorInfo(m_Variance.get()); - const TensorInfo& meanInfo = GetTensorInfo(m_Mean.get()); - const TensorInfo& gammaInfo = GetTensorInfo(m_Gamma.get()); - const TensorInfo& betaInfo = GetTensorInfo(m_Beta.get()); - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - - auto input = Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0); - auto var = Dequantize(m_Variance->GetConstTensor(), varInfo); - auto mean = Dequantize(m_Mean->GetConstTensor(), meanInfo); - auto gamma = Dequantize(m_Gamma->GetConstTensor(), gammaInfo); - auto beta = Dequantize(m_Beta->GetConstTensor(), betaInfo); - - std::vector results(outputInfo.GetNumElements()); - BatchNormImpl(m_Data, var.data(), mean.data(), gamma.data(), beta.data(), results.data(), input.data()); - Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo); -} - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.hpp b/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.hpp deleted file mode 100644 index 7c288a5c07..0000000000 --- a/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.hpp +++ /dev/null @@ -1,28 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include -#include - -namespace armnn -{ - -class RefBatchNormalizationUint8Workload : public Uint8Workload -{ -public: - explicit RefBatchNormalizationUint8Workload(const BatchNormalizationQueueDescriptor& descriptor, - const WorkloadInfo& info); - virtual void Execute() const override; - -private: - std::unique_ptr m_Mean; - std::unique_ptr m_Variance; - std::unique_ptr m_Beta; - std::unique_ptr m_Gamma; -}; - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp b/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp new file mode 100644 index 0000000000..b43b104459 --- /dev/null +++ b/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp @@ -0,0 +1,45 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefBatchNormalizationWorkload.hpp" + +#include "BatchNormImpl.hpp" +#include "RefWorkloadUtils.hpp" + +#include "Profiling.hpp" + +namespace armnn +{ + +RefBatchNormalizationWorkload::RefBatchNormalizationWorkload(const BatchNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload(descriptor, info) + , m_Mean (std::make_unique(*(descriptor.m_Mean))) + , m_Variance(std::make_unique(*(descriptor.m_Variance))) + , m_Beta (std::make_unique(*(descriptor.m_Beta))) + , m_Gamma (std::make_unique(*(descriptor.m_Gamma))) +{} + +void RefBatchNormalizationWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchNormalizationWorkload_Execute"); + + std::unique_ptr> meanDecoder = MakeDecoder(GetTensorInfo(m_Mean.get()), + m_Mean.get()->Map(true)); + std::unique_ptr> varianceDecoder = MakeDecoder(GetTensorInfo(m_Variance.get()), + m_Variance.get()->Map(true)); + std::unique_ptr> gammaDecoder = MakeDecoder(GetTensorInfo(m_Gamma.get()), + m_Gamma.get()->Map(true)); + std::unique_ptr> betaDecoder = MakeDecoder(GetTensorInfo(m_Beta.get()), + m_Beta.get()->Map(true)); + std::unique_ptr> inputDecoder = MakeDecoder(GetTensorInfo(m_Data.m_Inputs[0]), + m_Data.m_Inputs[0]->Map()); + std::unique_ptr> outputEncoder = MakeEncoder(GetTensorInfo(m_Data.m_Outputs[0]), + m_Data.m_Outputs[0]->Map()); + + BatchNormImpl(m_Data, *meanDecoder, *varianceDecoder, *betaDecoder, *gammaDecoder, *inputDecoder, *outputEncoder); +} + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp b/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp new file mode 100644 index 0000000000..9e71e7b4c5 --- /dev/null +++ b/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp @@ -0,0 +1,28 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include + +namespace armnn +{ + +class RefBatchNormalizationWorkload : public BaseWorkload +{ +public: + explicit RefBatchNormalizationWorkload(const BatchNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info); + virtual void Execute() const override; + +private: + std::unique_ptr m_Mean; + std::unique_ptr m_Variance; + std::unique_ptr m_Beta; + std::unique_ptr m_Gamma; +}; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index c8c26b0b83..7ccd4efc54 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -21,7 +21,7 @@ #include "RefGatherWorkload.hpp" #include "Softmax.hpp" #include "TensorBufferArrayView.hpp" -#include "RefBatchNormalizationFloat32Workload.hpp" +#include "RefBatchNormalizationWorkload.hpp" #include "Splitter.hpp" #include "RefDepthwiseConvolution2dWorkload.hpp" #include "FullyConnected.hpp" @@ -29,7 +29,6 @@ #include "RefFloorWorkload.hpp" #include "RefSoftmaxWorkload.hpp" #include "RefResizeBilinearFloat32Workload.hpp" -#include "RefBatchNormalizationUint8Workload.hpp" #include "ResizeBilinear.hpp" #include "RefNormalizationFloat32Workload.hpp" #include "RefDetectionPostProcessFloat32Workload.hpp" -- cgit v1.2.1