From 3122bd574a3d29774c535ca2136de361da626e88 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Mon, 3 Jun 2019 16:54:25 +0100 Subject: IVGCVSW-3212 Refactor the Reference BatchNormalization workloads to handle Float32 and QAsymm8 types * Removed the type-specific workload implementations * Added type-independent RefBatchNormalizationWorkload implementation * Reworked BatchNormImpl to use decoders/encoders * Improved the validation of the BatchNorm queue descriptor * Fixed unit tests where necessary Change-Id: Icf3fa1332292d38ec2fa0b1cb984cab78426034b Signed-off-by: Matteo Martincigh --- src/armnn/test/CreateWorkload.hpp | 9 ++- src/backends/backendsCommon/WorkloadData.cpp | 52 +++++++++----- src/backends/reference/RefLayerSupport.cpp | 42 ++++++++--- src/backends/reference/RefWorkloadFactory.cpp | 2 +- src/backends/reference/backend.mk | 4 +- .../reference/test/RefCreateWorkloadTests.cpp | 13 ++-- src/backends/reference/workloads/BatchNormImpl.cpp | 82 ++++++++++++++++++++++ src/backends/reference/workloads/BatchNormImpl.hpp | 62 +++------------- src/backends/reference/workloads/CMakeLists.txt | 7 +- src/backends/reference/workloads/Encoders.hpp | 2 + .../RefBatchNormalizationFloat32Workload.cpp | 38 ---------- .../RefBatchNormalizationFloat32Workload.hpp | 28 -------- .../RefBatchNormalizationUint8Workload.cpp | 47 ------------- .../RefBatchNormalizationUint8Workload.hpp | 28 -------- .../workloads/RefBatchNormalizationWorkload.cpp | 45 ++++++++++++ .../workloads/RefBatchNormalizationWorkload.hpp | 28 ++++++++ src/backends/reference/workloads/RefWorkloads.hpp | 3 +- 17 files changed, 254 insertions(+), 238 deletions(-) create mode 100644 src/backends/reference/workloads/BatchNormImpl.cpp delete mode 100644 src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.cpp delete mode 100644 src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp delete mode 100644 src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp delete mode 100644 src/backends/reference/workloads/RefBatchNormalizationUint8Workload.hpp create mode 100644 src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp create mode 100644 src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index 87df00af3c..c4b191a29f 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -134,11 +134,10 @@ std::unique_ptr CreateElementwiseWorkloadTest(armnn::IWorkloadFact return workload; } -template -std::unique_ptr CreateBatchNormalizationWorkloadTest( +template +std::unique_ptr CreateBatchNormalizationWorkloadTest( armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW) { - TensorShape tensorShape; switch (dataLayout) { @@ -171,14 +170,14 @@ std::unique_ptr CreateBatchNormalizationWorkl Layer* const input = graph.AddLayer(0, "input"); Layer* const output = graph.AddLayer(0, "output"); - //Connects up. + // Connects up. armnn::TensorInfo tensorInfo(tensorShape, DataType); Connect(input, layer, tensorInfo); Connect(layer, output, tensorInfo); CreateTensorHandles(graph, factory); // Makes the workload and checks it.
- auto workload = MakeAndCheckWorkload(*layer, graph, factory); + auto workload = MakeAndCheckWorkload(*layer, graph, factory); BatchNormalizationQueueDescriptor queueDescriptor = workload->GetData(); BOOST_TEST(queueDescriptor.m_Parameters.m_Eps == 0.05f); BOOST_TEST(queueDescriptor.m_Inputs.size() == 1); diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 6d17f3e042..a43619a466 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -684,28 +684,48 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf { ValidateNumInputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1); ValidateNumOutputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1); + + const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; + + std::vector supportedTypes = + { + DataType::Float16, + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(input, supportedTypes, "BatchNormalizationQueueDescriptor"); + ValidateDataTypes(output, supportedTypes, "BatchNormalizationQueueDescriptor"); + + ValidateDataTypes(output, { input.GetDataType() }, "BatchNormalizationQueueDescriptor"); + + ValidateTensorQuantizationSpace(input, output, "BatchNormalizationQueueDescriptor", "input", "output"); + ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], "BatchNormalizationQueueDescriptor", "input", "output"); - ValidatePointer(m_Mean, "BatchNormalizationQueueDescriptor", "mean"); - ValidatePointer(m_Variance, "BatchNormalizationQueueDescriptor", "variance"); - ValidatePointer(m_Beta, "BatchNormalizationQueueDescriptor", "beta"); - ValidatePointer(m_Gamma, "BatchNormalizationQueueDescriptor", "gamma"); - - ValidateTensorNumDimensions(m_Mean->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "mean"); - ValidateTensorNumDimensions(m_Variance->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "variance"); - ValidateTensorNumDimensions(m_Beta->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "beta"); - ValidateTensorNumDimensions(m_Gamma->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "gamma"); - - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Variance->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "variance"); - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Beta->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "beta"); - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Gamma->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "gamma"); + ValidatePointer(m_Mean, "BatchNormalizationQueueDescriptor", "mean"); + ValidatePointer(m_Variance, "BatchNormalizationQueueDescriptor", "variance"); + ValidatePointer(m_Beta, "BatchNormalizationQueueDescriptor", "beta"); + ValidatePointer(m_Gamma, "BatchNormalizationQueueDescriptor", "gamma"); + + const TensorInfo& mean = m_Mean->GetTensorInfo(); + const TensorInfo& variance = m_Variance->GetTensorInfo(); + const TensorInfo& beta = m_Beta->GetTensorInfo(); + const TensorInfo& gamma = m_Gamma->GetTensorInfo(); + + ValidateTensorNumDimensions(mean, "BatchNormalizationQueueDescriptor", 1, "mean"); + ValidateTensorNumDimensions(variance, "BatchNormalizationQueueDescriptor", 1, "variance"); + ValidateTensorNumDimensions(beta, "BatchNormalizationQueueDescriptor", 1, "beta"); + ValidateTensorNumDimensions(gamma, 
"BatchNormalizationQueueDescriptor", 1, "gamma"); + + ValidateTensorShapesMatch(mean, variance, "BatchNormalizationQueueDescriptor", "mean", "variance"); + ValidateTensorShapesMatch(mean, beta, "BatchNormalizationQueueDescriptor", "mean", "beta"); + ValidateTensorShapesMatch(mean, gamma, "BatchNormalizationQueueDescriptor", "mean", "gamma"); } void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index edd552b2ac..aeff51d853 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -280,22 +280,44 @@ bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0, bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input, const TensorInfo& output, const TensorInfo& mean, - const TensorInfo& var, + const TensorInfo& variance, const TensorInfo& beta, const TensorInfo& gamma, const BatchNormalizationDescriptor& descriptor, Optional reasonIfUnsupported) const { - ignore_unused(output); - ignore_unused(mean); - ignore_unused(var); - ignore_unused(beta); - ignore_unused(gamma); ignore_unused(descriptor); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + + std::array supportedTypes = + { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + bool supported = true; + + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference batch normalization: input is not a supported type."); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference batch normalization: output is not a supported type."); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference batch normalization: input and output types are mismatched"); + + supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported, + "Reference batch normalization: mean is not a supported type."); + + supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported, + "Reference batch normalization: variance is not a supported type."); + + supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported, + "Reference batch normalization: beta is not a supported type."); + + supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported, + "Reference batch normalization: gamma is not a supported type."); + + return supported; } bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 161065550d..d103f56c23 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -220,7 +220,7 @@ std::unique_ptr RefWorkloadFactory::CreateMultiplication( std::unique_ptr RefWorkloadFactory::CreateBatchNormalization( const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor, diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index f371c8bd0b..81b6de18e4 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -12,6 +12,7 @@ BACKEND_SOURCES := \ 
RefLayerSupport.cpp \ RefWorkloadFactory.cpp \ workloads/Activation.cpp \ + workloads/BatchNormImpl.cpp \ workloads/BatchToSpaceNd.cpp \ workloads/Broadcast.cpp \ workloads/ConvImpl.cpp \ @@ -25,8 +26,7 @@ BACKEND_SOURCES := \ workloads/Pad.cpp \ workloads/Pooling2d.cpp \ workloads/RefActivationWorkload.cpp \ - workloads/RefBatchNormalizationFloat32Workload.cpp \ - workloads/RefBatchNormalizationUint8Workload.cpp \ + workloads/RefBatchNormalizationWorkload.cpp \ workloads/RefBatchToSpaceNdFloat32Workload.cpp \ workloads/RefBatchToSpaceNdUint8Workload.cpp \ workloads/RefConcatWorkload.cpp \ diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 7c5712b915..a0c614564d 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -181,8 +181,9 @@ static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout) { Graph graph; RefWorkloadFactory factory; - auto workload = - CreateBatchNormalizationWorkloadTest(factory, graph, dataLayout); + auto workload = CreateBatchNormalizationWorkloadTest(factory, + graph, + dataLayout); TensorShape inputShape; TensorShape outputShape; @@ -206,25 +207,25 @@ static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout) BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat32Workload) { - RefCreateBatchNormalizationWorkloadTest + RefCreateBatchNormalizationWorkloadTest (DataLayout::NCHW); } BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat32WorkloadNhwc) { - RefCreateBatchNormalizationWorkloadTest + RefCreateBatchNormalizationWorkloadTest (DataLayout::NHWC); } BOOST_AUTO_TEST_CASE(CreateBatchNormalizationUint8Workload) { - RefCreateBatchNormalizationWorkloadTest + RefCreateBatchNormalizationWorkloadTest (DataLayout::NCHW); } BOOST_AUTO_TEST_CASE(CreateBatchNormalizationUint8WorkloadNhwc) { - RefCreateBatchNormalizationWorkloadTest + RefCreateBatchNormalizationWorkloadTest (DataLayout::NHWC); } diff --git a/src/backends/reference/workloads/BatchNormImpl.cpp b/src/backends/reference/workloads/BatchNormImpl.cpp new file mode 100644 index 0000000000..36e96d3fec --- /dev/null +++ b/src/backends/reference/workloads/BatchNormImpl.cpp @@ -0,0 +1,82 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#include "BatchNormImpl.hpp" +#include "RefWorkloadUtils.hpp" + +#include + +#include + +#include + +namespace armnn +{ + +void BatchNormImpl(const BatchNormalizationQueueDescriptor& data, + Decoder& meanDecoder, + Decoder& varianceDecoder, + Decoder& betaDecoder, + Decoder& gammaDecoder, + Decoder& inputDecoder, + Encoder& outputEncoder) +{ + const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]); + const TensorShape inputShape = inputInfo.GetShape(); + + armnnUtils::DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout); + + unsigned int inputBatches = inputShape[0]; + unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()]; + unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()]; + unsigned int inputChannels = inputShape[dataLayout.GetChannelsIndex()]; + + for (unsigned int c = 0; c < inputChannels; c++) + { + meanDecoder[c]; + varianceDecoder[c]; + betaDecoder[c]; + gammaDecoder[c]; + float mean = meanDecoder.Get(); + float var = varianceDecoder.Get(); + float beta = betaDecoder.Get(); + float gamma = gammaDecoder.Get(); + + float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps); + float add = beta - mult * mean; + + for (unsigned int n = 0; n < inputBatches; n++) + { + for (unsigned int h = 0; h < inputHeight; h++) + { + for (unsigned int w = 0; w < inputWidth; w++) + { + unsigned int index = 0; + + if (dataLayout == DataLayout::NHWC) + { + index = n * inputHeight * inputWidth * inputChannels + + h * inputWidth * inputChannels + + w * inputChannels + + c; + } + else // dataLayout == DataLayout::NCHW + { + index = n * inputHeight * inputWidth * inputChannels + + c * inputHeight * inputWidth + + h * inputWidth + + w; + } + + inputDecoder[index]; + outputEncoder[index]; + outputEncoder.Set(mult * inputDecoder.Get() + add); + } + } + } + } +} + +} // namespace armnn diff --git a/src/backends/reference/workloads/BatchNormImpl.hpp b/src/backends/reference/workloads/BatchNormImpl.hpp index 799e7a327b..c0250b9e0f 100644 --- a/src/backends/reference/workloads/BatchNormImpl.hpp +++ b/src/backends/reference/workloads/BatchNormImpl.hpp @@ -5,60 +5,20 @@ #pragma once -#include "RefWorkloadUtils.hpp" -#include "TensorBufferArrayView.hpp" +#include "Encoders.hpp" +#include "Decoders.hpp" -#include - -#include - -#include +#include namespace armnn { -template -static void BatchNormImpl(NormData data, - const float* varIn, - const float* meanIn, - const float* gammaIn, - const float* betaIn, - float* outputData, - const float* inputData) -{ - const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]); - const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]); - - TensorBufferArrayView input(inputInfo.GetShape(), - inputData, - data.m_Parameters.m_DataLayout); - TensorBufferArrayView output(outputInfo.GetShape(), - outputData, - data.m_Parameters.m_DataLayout); - - armnnUtils::DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout); - - for (unsigned int c = 0; c < inputInfo.GetShape()[dataLayout.GetChannelsIndex()]; c++) - { - float var = varIn[c]; - float mean = meanIn[c]; - float gamma = gammaIn[c]; - float beta = betaIn[c]; - - float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps); - float add = beta - mult * mean; - - for (unsigned int n = 0; n < inputInfo.GetShape()[0]; n++) - { - for (unsigned int h = 0; h < inputInfo.GetShape()[dataLayout.GetHeightIndex()]; h++) - { - for (unsigned int w = 0; w < inputInfo.GetShape()[dataLayout.GetWidthIndex()]; w++) - { - output.Get(n, c, h, w) = mult * 
input.Get(n, c, h, w) + add; - } - } - } - } -} +void BatchNormImpl(const BatchNormalizationQueueDescriptor& data, + Decoder& meanIn, + Decoder& varIn, + Decoder& betaIn, + Decoder& gammaIn, + Decoder& inputData, + Encoder& outputData); -} //namespace armnn +} // namespace armnn diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index df126c4308..cdca22da31 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -7,6 +7,7 @@ list(APPEND armnnRefBackendWorkloads_sources Activation.cpp Activation.hpp BaseIterator.hpp + BatchNormImpl.cpp BatchNormImpl.hpp BatchToSpaceNd.cpp BatchToSpaceNd.hpp @@ -37,10 +38,8 @@ list(APPEND armnnRefBackendWorkloads_sources Pooling2d.hpp RefActivationWorkload.cpp RefActivationWorkload.hpp - RefBatchNormalizationFloat32Workload.cpp - RefBatchNormalizationFloat32Workload.hpp - RefBatchNormalizationUint8Workload.cpp - RefBatchNormalizationUint8Workload.hpp + RefBatchNormalizationWorkload.cpp + RefBatchNormalizationWorkload.hpp RefBatchToSpaceNdFloat32Workload.cpp RefBatchToSpaceNdFloat32Workload.hpp RefBatchToSpaceNdUint8Workload.cpp diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp index 547bead98a..af3b937c2a 100644 --- a/src/backends/reference/workloads/Encoders.hpp +++ b/src/backends/reference/workloads/Encoders.hpp @@ -7,6 +7,8 @@ #include "BaseIterator.hpp" +#include + namespace armnn { diff --git a/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.cpp b/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.cpp deleted file mode 100644 index 313af9c438..0000000000 --- a/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.cpp +++ /dev/null @@ -1,38 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefBatchNormalizationFloat32Workload.hpp" - -#include "BatchNormImpl.hpp" -#include "RefWorkloadUtils.hpp" - -#include "Profiling.hpp" - -namespace armnn -{ -RefBatchNormalizationFloat32Workload::RefBatchNormalizationFloat32Workload( - const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) - : Float32Workload(descriptor, info), - m_Mean(std::make_unique(*(descriptor.m_Mean))), - m_Variance(std::make_unique(*(descriptor.m_Variance))), - m_Beta(std::make_unique(*(descriptor.m_Beta))), - m_Gamma(std::make_unique(*(descriptor.m_Gamma))) {} - -void RefBatchNormalizationFloat32Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchNormalizationFloat32Workload_Execute"); - - const float* var = m_Variance->GetConstTensor(); - const float* mean = m_Mean->GetConstTensor(); - const float* gamma = m_Gamma->GetConstTensor(); - const float* beta = m_Beta->GetConstTensor(); - - auto inputData = GetInputTensorDataFloat(0, m_Data); - auto outputData = GetOutputTensorDataFloat(0, m_Data); - - BatchNormImpl(m_Data, var, mean, gamma, beta, outputData, inputData); -} - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp b/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp deleted file mode 100644 index 9f92899f4f..0000000000 --- a/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp +++ /dev/null @@ -1,28 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. 
-// SPDX-License-Identifier: MIT -// - -#pragma once - -#include -#include - -namespace armnn -{ - -class RefBatchNormalizationFloat32Workload : public Float32Workload -{ -public: - explicit RefBatchNormalizationFloat32Workload(const BatchNormalizationQueueDescriptor& descriptor, - const WorkloadInfo& info); - virtual void Execute() const override; - -private: - std::unique_ptr m_Mean; - std::unique_ptr m_Variance; - std::unique_ptr m_Beta; - std::unique_ptr m_Gamma; -}; - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp b/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp deleted file mode 100644 index e248ad4b9d..0000000000 --- a/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp +++ /dev/null @@ -1,47 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefBatchNormalizationUint8Workload.hpp" - -#include "BatchNormImpl.hpp" -#include "RefWorkloadUtils.hpp" - -#include "Profiling.hpp" - -#include - -namespace armnn -{ -RefBatchNormalizationUint8Workload::RefBatchNormalizationUint8Workload( - const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) - : Uint8Workload(descriptor, info), - m_Mean(std::make_unique(*(descriptor.m_Mean))), - m_Variance(std::make_unique(*(descriptor.m_Variance))), - m_Beta(std::make_unique(*(descriptor.m_Beta))), - m_Gamma(std::make_unique(*(descriptor.m_Gamma))) {} - -void RefBatchNormalizationUint8Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchNormalizationUint8Workload_Execute"); - - const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorInfo& varInfo = GetTensorInfo(m_Variance.get()); - const TensorInfo& meanInfo = GetTensorInfo(m_Mean.get()); - const TensorInfo& gammaInfo = GetTensorInfo(m_Gamma.get()); - const TensorInfo& betaInfo = GetTensorInfo(m_Beta.get()); - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - - auto input = Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0); - auto var = Dequantize(m_Variance->GetConstTensor(), varInfo); - auto mean = Dequantize(m_Mean->GetConstTensor(), meanInfo); - auto gamma = Dequantize(m_Gamma->GetConstTensor(), gammaInfo); - auto beta = Dequantize(m_Beta->GetConstTensor(), betaInfo); - - std::vector results(outputInfo.GetNumElements()); - BatchNormImpl(m_Data, var.data(), mean.data(), gamma.data(), beta.data(), results.data(), input.data()); - Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo); -} - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.hpp b/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.hpp deleted file mode 100644 index 7c288a5c07..0000000000 --- a/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.hpp +++ /dev/null @@ -1,28 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. 
-// SPDX-License-Identifier: MIT -// - -#pragma once - -#include -#include - -namespace armnn -{ - -class RefBatchNormalizationUint8Workload : public Uint8Workload -{ -public: - explicit RefBatchNormalizationUint8Workload(const BatchNormalizationQueueDescriptor& descriptor, - const WorkloadInfo& info); - virtual void Execute() const override; - -private: - std::unique_ptr m_Mean; - std::unique_ptr m_Variance; - std::unique_ptr m_Beta; - std::unique_ptr m_Gamma; -}; - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp b/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp new file mode 100644 index 0000000000..b43b104459 --- /dev/null +++ b/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp @@ -0,0 +1,45 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefBatchNormalizationWorkload.hpp" + +#include "BatchNormImpl.hpp" +#include "RefWorkloadUtils.hpp" + +#include "Profiling.hpp" + +namespace armnn +{ + +RefBatchNormalizationWorkload::RefBatchNormalizationWorkload(const BatchNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload(descriptor, info) + , m_Mean (std::make_unique(*(descriptor.m_Mean))) + , m_Variance(std::make_unique(*(descriptor.m_Variance))) + , m_Beta (std::make_unique(*(descriptor.m_Beta))) + , m_Gamma (std::make_unique(*(descriptor.m_Gamma))) +{} + +void RefBatchNormalizationWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchNormalizationWorkload_Execute"); + + std::unique_ptr> meanDecoder = MakeDecoder(GetTensorInfo(m_Mean.get()), + m_Mean.get()->Map(true)); + std::unique_ptr> varianceDecoder = MakeDecoder(GetTensorInfo(m_Variance.get()), + m_Variance.get()->Map(true)); + std::unique_ptr> gammaDecoder = MakeDecoder(GetTensorInfo(m_Gamma.get()), + m_Gamma.get()->Map(true)); + std::unique_ptr> betaDecoder = MakeDecoder(GetTensorInfo(m_Beta.get()), + m_Beta.get()->Map(true)); + std::unique_ptr> inputDecoder = MakeDecoder(GetTensorInfo(m_Data.m_Inputs[0]), + m_Data.m_Inputs[0]->Map()); + std::unique_ptr> outputEncoder = MakeEncoder(GetTensorInfo(m_Data.m_Outputs[0]), + m_Data.m_Outputs[0]->Map()); + + BatchNormImpl(m_Data, *meanDecoder, *varianceDecoder, *betaDecoder, *gammaDecoder, *inputDecoder, *outputEncoder); +} + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp b/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp new file mode 100644 index 0000000000..9e71e7b4c5 --- /dev/null +++ b/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp @@ -0,0 +1,28 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include + +namespace armnn +{ + +class RefBatchNormalizationWorkload : public BaseWorkload +{ +public: + explicit RefBatchNormalizationWorkload(const BatchNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info); + virtual void Execute() const override; + +private: + std::unique_ptr m_Mean; + std::unique_ptr m_Variance; + std::unique_ptr m_Beta; + std::unique_ptr m_Gamma; +}; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index c8c26b0b83..7ccd4efc54 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -21,7 +21,7 @@ #include "RefGatherWorkload.hpp" #include "Softmax.hpp" #include "TensorBufferArrayView.hpp" -#include "RefBatchNormalizationFloat32Workload.hpp" +#include "RefBatchNormalizationWorkload.hpp" #include "Splitter.hpp" #include "RefDepthwiseConvolution2dWorkload.hpp" #include "FullyConnected.hpp" @@ -29,7 +29,6 @@ #include "RefFloorWorkload.hpp" #include "RefSoftmaxWorkload.hpp" #include "RefResizeBilinearFloat32Workload.hpp" -#include "RefBatchNormalizationUint8Workload.hpp" #include "ResizeBilinear.hpp" #include "RefNormalizationFloat32Workload.hpp" #include "RefDetectionPostProcessFloat32Workload.hpp" -- cgit v1.2.1
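The arithmetic performed by the new BatchNormImpl is the standard batch normalization transform, folded into one multiply-add per channel: mult = gamma / sqrt(variance + eps) and add = beta - mult * mean, so each element becomes mult * x + add. The sketch below restates that folding on a plain NCHW float buffer; it is a simplified stand-in for the decoder/encoder-based loop in the patch, and the function name, raw pointers and fixed layout are assumptions made for illustration only.

// Simplified sketch of the per-channel math in BatchNormImpl, applied directly
// to NCHW float buffers. The real workload reads and writes through
// Decoder<float>/Encoder<float>, which is what lets the same loop serve both
// Float32 and QAsymm8 tensors.
#include <cmath>
#include <cstddef>

void BatchNormNchwSketch(const float* input, float* output,
                         const float* mean, const float* variance,
                         const float* beta, const float* gamma,
                         float eps,
                         unsigned int batches, unsigned int channels,
                         unsigned int height, unsigned int width)
{
    for (unsigned int c = 0; c < channels; ++c)
    {
        // Fold the per-channel parameters into one multiply and one add.
        const float mult = gamma[c] / std::sqrt(variance[c] + eps);
        const float add  = beta[c] - mult * mean[c];

        for (unsigned int n = 0; n < batches; ++n)
        {
            const std::size_t planeStart =
                (static_cast<std::size_t>(n) * channels + c) * height * width;
            const std::size_t planeSize = static_cast<std::size_t>(height) * width;
            for (std::size_t i = 0; i < planeSize; ++i)
            {
                output[planeStart + i] = mult * input[planeStart + i] + add;
            }
        }
    }
}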
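What allows a single workload to replace the separate Float32 and Uint8 implementations is the Decoder<float>/Encoder<float> access pattern visible above: operator[] positions the iterator on an element, Get() returns its value as a float (dequantizing when the backing tensor is QAsymm8), and Set() writes a float back (quantizing on the way out). The deleted Uint8 workload achieved the same effect by dequantizing whole tensors up front and re-quantizing the result. A minimal sketch of the per-element idea follows; these classes and the quantization parameters are illustrative assumptions, not ArmNN's actual BaseIterator hierarchy.

// Illustrative sketch of a QAsymm8 decoder/encoder pair: the kernel is written
// once against float Get()/Set(), and the iterators convert to and from the
// stored uint8_t representation element by element.
#include <algorithm>
#include <cmath>
#include <cstdint>

struct QAsymm8DecoderSketch
{
    const uint8_t* data;
    float scale;
    int32_t offset;
    unsigned int index = 0;

    QAsymm8DecoderSketch& operator[](unsigned int i) { index = i; return *this; }

    // Dequantize the current element to float.
    float Get() const { return scale * (static_cast<int32_t>(data[index]) - offset); }
};

struct QAsymm8EncoderSketch
{
    uint8_t* data;
    float scale;
    int32_t offset;
    unsigned int index = 0;

    QAsymm8EncoderSketch& operator[](unsigned int i) { index = i; return *this; }

    // Quantize a float result back into the uint8_t output tensor.
    void Set(float value)
    {
        const int32_t q = static_cast<int32_t>(std::round(value / scale)) + offset;
        data[index] = static_cast<uint8_t>(std::min<int32_t>(255, std::max<int32_t>(0, q)));
    }
};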
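The two hand-rolled index formulas in BatchNormImpl are simply the row-major offsets for the NHWC and NCHW layouts. A quick self-check over a hypothetical N=1, H=2, W=2, C=3 tensor (sizes chosen only for illustration) confirms that each formula enumerates the elements of its layout in storage order:

// Sanity check of the flat-index formulas used in BatchNormImpl for the two
// supported data layouts, on an illustrative 1x2x2x3 tensor.
#include <cassert>

int main()
{
    const unsigned int N = 1, H = 2, W = 2, C = 3;

    unsigned int expected = 0;
    for (unsigned int n = 0; n < N; ++n)
        for (unsigned int h = 0; h < H; ++h)
            for (unsigned int w = 0; w < W; ++w)
                for (unsigned int c = 0; c < C; ++c)
                {
                    // NHWC: channels vary fastest.
                    const unsigned int index = n * H * W * C + h * W * C + w * C + c;
                    assert(index == expected);
                    ++expected;
                }

    expected = 0;
    for (unsigned int n = 0; n < N; ++n)
        for (unsigned int c = 0; c < C; ++c)
            for (unsigned int h = 0; h < H; ++h)
                for (unsigned int w = 0; w < W; ++w)
                {
                    // NCHW: width varies fastest.
                    const unsigned int index = n * C * H * W + c * H * W + h * W + w;
                    assert(index == expected);
                    ++expected;
                }

    return 0;
}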