diff options
Diffstat (limited to 'src/backends')
15 files changed, 225 insertions, 208 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 6d17f3e042..a43619a466 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -684,28 +684,48 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf { ValidateNumInputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1); ValidateNumOutputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1); + + const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; + + std::vector<DataType> supportedTypes = + { + DataType::Float16, + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(input, supportedTypes, "BatchNormalizationQueueDescriptor"); + ValidateDataTypes(output, supportedTypes, "BatchNormalizationQueueDescriptor"); + + ValidateDataTypes(output, { input.GetDataType() }, "BatchNormalizationQueueDescriptor"); + + ValidateTensorQuantizationSpace(input, output, "BatchNormalizationQueueDescriptor", "input", "output"); + ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], "BatchNormalizationQueueDescriptor", "input", "output"); - ValidatePointer(m_Mean, "BatchNormalizationQueueDescriptor", "mean"); - ValidatePointer(m_Variance, "BatchNormalizationQueueDescriptor", "variance"); - ValidatePointer(m_Beta, "BatchNormalizationQueueDescriptor", "beta"); - ValidatePointer(m_Gamma, "BatchNormalizationQueueDescriptor", "gamma"); - - ValidateTensorNumDimensions(m_Mean->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "mean"); - ValidateTensorNumDimensions(m_Variance->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "variance"); - ValidateTensorNumDimensions(m_Beta->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "beta"); - ValidateTensorNumDimensions(m_Gamma->GetTensorInfo(), "BatchNormalizationQueueDescriptor", 1, "gamma"); - - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Variance->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "variance"); - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Beta->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "beta"); - ValidateTensorShapesMatch( - m_Mean->GetTensorInfo(), m_Gamma->GetTensorInfo(), "BatchNormalizationQueueDescriptor", "mean", "gamma"); + ValidatePointer(m_Mean, "BatchNormalizationQueueDescriptor", "mean"); + ValidatePointer(m_Variance, "BatchNormalizationQueueDescriptor", "variance"); + ValidatePointer(m_Beta, "BatchNormalizationQueueDescriptor", "beta"); + ValidatePointer(m_Gamma, "BatchNormalizationQueueDescriptor", "gamma"); + + const TensorInfo& mean = m_Mean->GetTensorInfo(); + const TensorInfo& variance = m_Variance->GetTensorInfo(); + const TensorInfo& beta = m_Beta->GetTensorInfo(); + const TensorInfo& gamma = m_Gamma->GetTensorInfo(); + + ValidateTensorNumDimensions(mean, "BatchNormalizationQueueDescriptor", 1, "mean"); + ValidateTensorNumDimensions(variance, "BatchNormalizationQueueDescriptor", 1, "variance"); + ValidateTensorNumDimensions(beta, "BatchNormalizationQueueDescriptor", 1, "beta"); + ValidateTensorNumDimensions(gamma, "BatchNormalizationQueueDescriptor", 1, "gamma"); + + ValidateTensorShapesMatch(mean, variance, "BatchNormalizationQueueDescriptor", "mean", "variance"); + ValidateTensorShapesMatch(mean, beta, "BatchNormalizationQueueDescriptor", "mean", "beta"); + ValidateTensorShapesMatch(mean, gamma, "BatchNormalizationQueueDescriptor", "mean", "gamma"); } void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index edd552b2ac..aeff51d853 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -280,22 +280,44 @@ bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0, bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input, const TensorInfo& output, const TensorInfo& mean, - const TensorInfo& var, + const TensorInfo& variance, const TensorInfo& beta, const TensorInfo& gamma, const BatchNormalizationDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported) const { - ignore_unused(output); - ignore_unused(mean); - ignore_unused(var); - ignore_unused(beta); - ignore_unused(gamma); ignore_unused(descriptor); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + + std::array<DataType, 2> supportedTypes = + { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + bool supported = true; + + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference batch normalization: input is not a supported type."); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference batch normalization: output is not a supported type."); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference batch normalization: input and output types are mismatched"); + + supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported, + "Reference batch normalization: mean is not a supported type."); + + supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported, + "Reference batch normalization: variance is not a supported type."); + + supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported, + "Reference batch normalization: beta is not a supported type."); + + supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported, + "Reference batch normalization: gamma is not a supported type."); + + return supported; } bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 161065550d..d103f56c23 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -220,7 +220,7 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMultiplication( std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateBatchNormalization( const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<RefBatchNormalizationFloat32Workload, RefBatchNormalizationUint8Workload>(descriptor, info); + return std::make_unique<RefBatchNormalizationWorkload>(descriptor, info); } std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor, diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index f371c8bd0b..81b6de18e4 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -12,6 +12,7 @@ BACKEND_SOURCES := \ RefLayerSupport.cpp \ RefWorkloadFactory.cpp \ workloads/Activation.cpp \ + workloads/BatchNormImpl.cpp \ workloads/BatchToSpaceNd.cpp \ workloads/Broadcast.cpp \ workloads/ConvImpl.cpp \ @@ -25,8 +26,7 @@ BACKEND_SOURCES := \ workloads/Pad.cpp \ workloads/Pooling2d.cpp \ workloads/RefActivationWorkload.cpp \ - workloads/RefBatchNormalizationFloat32Workload.cpp \ - workloads/RefBatchNormalizationUint8Workload.cpp \ + workloads/RefBatchNormalizationWorkload.cpp \ workloads/RefBatchToSpaceNdFloat32Workload.cpp \ workloads/RefBatchToSpaceNdUint8Workload.cpp \ workloads/RefConcatWorkload.cpp \ diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 7c5712b915..a0c614564d 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -181,8 +181,9 @@ static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout) { Graph graph; RefWorkloadFactory factory; - auto workload = - CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>(factory, graph, dataLayout); + auto workload = CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>(factory, + graph, + dataLayout); TensorShape inputShape; TensorShape outputShape; @@ -206,25 +207,25 @@ static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout) BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat32Workload) { - RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationFloat32Workload,armnn::DataType::Float32> + RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload,armnn::DataType::Float32> (DataLayout::NCHW); } BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat32WorkloadNhwc) { - RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationFloat32Workload, armnn::DataType::Float32> + RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::Float32> (DataLayout::NHWC); } BOOST_AUTO_TEST_CASE(CreateBatchNormalizationUint8Workload) { - RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationUint8Workload, armnn::DataType::QuantisedAsymm8> + RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QuantisedAsymm8> (DataLayout::NCHW); } BOOST_AUTO_TEST_CASE(CreateBatchNormalizationUint8WorkloadNhwc) { - RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationUint8Workload, armnn::DataType::QuantisedAsymm8> + RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QuantisedAsymm8> (DataLayout::NHWC); } diff --git a/src/backends/reference/workloads/BatchNormImpl.cpp b/src/backends/reference/workloads/BatchNormImpl.cpp new file mode 100644 index 0000000000..36e96d3fec --- /dev/null +++ b/src/backends/reference/workloads/BatchNormImpl.cpp @@ -0,0 +1,82 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "BatchNormImpl.hpp" +#include "RefWorkloadUtils.hpp" + +#include <armnn/Tensor.hpp> + +#include <DataLayoutIndexed.hpp> + +#include <cmath> + +namespace armnn +{ + +void BatchNormImpl(const BatchNormalizationQueueDescriptor& data, + Decoder<float>& meanDecoder, + Decoder<float>& varianceDecoder, + Decoder<float>& betaDecoder, + Decoder<float>& gammaDecoder, + Decoder<float>& inputDecoder, + Encoder<float>& outputEncoder) +{ + const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]); + const TensorShape inputShape = inputInfo.GetShape(); + + armnnUtils::DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout); + + unsigned int inputBatches = inputShape[0]; + unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()]; + unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()]; + unsigned int inputChannels = inputShape[dataLayout.GetChannelsIndex()]; + + for (unsigned int c = 0; c < inputChannels; c++) + { + meanDecoder[c]; + varianceDecoder[c]; + betaDecoder[c]; + gammaDecoder[c]; + float mean = meanDecoder.Get(); + float var = varianceDecoder.Get(); + float beta = betaDecoder.Get(); + float gamma = gammaDecoder.Get(); + + float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps); + float add = beta - mult * mean; + + for (unsigned int n = 0; n < inputBatches; n++) + { + for (unsigned int h = 0; h < inputHeight; h++) + { + for (unsigned int w = 0; w < inputWidth; w++) + { + unsigned int index = 0; + + if (dataLayout == DataLayout::NHWC) + { + index = n * inputHeight * inputWidth * inputChannels + + h * inputWidth * inputChannels + + w * inputChannels + + c; + } + else // dataLayout == DataLayout::NCHW + { + index = n * inputHeight * inputWidth * inputChannels + + c * inputHeight * inputWidth + + h * inputWidth + + w; + } + + inputDecoder[index]; + outputEncoder[index]; + outputEncoder.Set(mult * inputDecoder.Get() + add); + } + } + } + } +} + +} // namespace armnn diff --git a/src/backends/reference/workloads/BatchNormImpl.hpp b/src/backends/reference/workloads/BatchNormImpl.hpp index 799e7a327b..c0250b9e0f 100644 --- a/src/backends/reference/workloads/BatchNormImpl.hpp +++ b/src/backends/reference/workloads/BatchNormImpl.hpp @@ -5,60 +5,20 @@ #pragma once -#include "RefWorkloadUtils.hpp" -#include "TensorBufferArrayView.hpp" +#include "Encoders.hpp" +#include "Decoders.hpp" -#include <armnn/Tensor.hpp> - -#include <DataLayoutIndexed.hpp> - -#include <cmath> +#include <backendsCommon/WorkloadData.hpp> namespace armnn { -template<typename NormData> -static void BatchNormImpl(NormData data, - const float* varIn, - const float* meanIn, - const float* gammaIn, - const float* betaIn, - float* outputData, - const float* inputData) -{ - const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]); - const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]); - - TensorBufferArrayView<const float> input(inputInfo.GetShape(), - inputData, - data.m_Parameters.m_DataLayout); - TensorBufferArrayView<float> output(outputInfo.GetShape(), - outputData, - data.m_Parameters.m_DataLayout); - - armnnUtils::DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout); - - for (unsigned int c = 0; c < inputInfo.GetShape()[dataLayout.GetChannelsIndex()]; c++) - { - float var = varIn[c]; - float mean = meanIn[c]; - float gamma = gammaIn[c]; - float beta = betaIn[c]; - - float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps); - float add = beta - mult * mean; - - for (unsigned int n = 0; n < inputInfo.GetShape()[0]; n++) - { - for (unsigned int h = 0; h < inputInfo.GetShape()[dataLayout.GetHeightIndex()]; h++) - { - for (unsigned int w = 0; w < inputInfo.GetShape()[dataLayout.GetWidthIndex()]; w++) - { - output.Get(n, c, h, w) = mult * input.Get(n, c, h, w) + add; - } - } - } - } -} +void BatchNormImpl(const BatchNormalizationQueueDescriptor& data, + Decoder<float>& meanIn, + Decoder<float>& varIn, + Decoder<float>& betaIn, + Decoder<float>& gammaIn, + Decoder<float>& inputData, + Encoder<float>& outputData); -} //namespace armnn +} // namespace armnn diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index df126c4308..cdca22da31 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -7,6 +7,7 @@ list(APPEND armnnRefBackendWorkloads_sources Activation.cpp Activation.hpp BaseIterator.hpp + BatchNormImpl.cpp BatchNormImpl.hpp BatchToSpaceNd.cpp BatchToSpaceNd.hpp @@ -37,10 +38,8 @@ list(APPEND armnnRefBackendWorkloads_sources Pooling2d.hpp RefActivationWorkload.cpp RefActivationWorkload.hpp - RefBatchNormalizationFloat32Workload.cpp - RefBatchNormalizationFloat32Workload.hpp - RefBatchNormalizationUint8Workload.cpp - RefBatchNormalizationUint8Workload.hpp + RefBatchNormalizationWorkload.cpp + RefBatchNormalizationWorkload.hpp RefBatchToSpaceNdFloat32Workload.cpp RefBatchToSpaceNdFloat32Workload.hpp RefBatchToSpaceNdUint8Workload.cpp diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp index 547bead98a..af3b937c2a 100644 --- a/src/backends/reference/workloads/Encoders.hpp +++ b/src/backends/reference/workloads/Encoders.hpp @@ -7,6 +7,8 @@ #include "BaseIterator.hpp" +#include <boost/assert.hpp> + namespace armnn { diff --git a/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.cpp b/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.cpp deleted file mode 100644 index 313af9c438..0000000000 --- a/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.cpp +++ /dev/null @@ -1,38 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefBatchNormalizationFloat32Workload.hpp" - -#include "BatchNormImpl.hpp" -#include "RefWorkloadUtils.hpp" - -#include "Profiling.hpp" - -namespace armnn -{ -RefBatchNormalizationFloat32Workload::RefBatchNormalizationFloat32Workload( - const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) - : Float32Workload<BatchNormalizationQueueDescriptor>(descriptor, info), - m_Mean(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Mean))), - m_Variance(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Variance))), - m_Beta(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Beta))), - m_Gamma(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Gamma))) {} - -void RefBatchNormalizationFloat32Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchNormalizationFloat32Workload_Execute"); - - const float* var = m_Variance->GetConstTensor<float>(); - const float* mean = m_Mean->GetConstTensor<float>(); - const float* gamma = m_Gamma->GetConstTensor<float>(); - const float* beta = m_Beta->GetConstTensor<float>(); - - auto inputData = GetInputTensorDataFloat(0, m_Data); - auto outputData = GetOutputTensorDataFloat(0, m_Data); - - BatchNormImpl(m_Data, var, mean, gamma, beta, outputData, inputData); -} - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp b/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp deleted file mode 100644 index 9f92899f4f..0000000000 --- a/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp +++ /dev/null @@ -1,28 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include <backendsCommon/Workload.hpp> -#include <backendsCommon/WorkloadData.hpp> - -namespace armnn -{ - -class RefBatchNormalizationFloat32Workload : public Float32Workload<BatchNormalizationQueueDescriptor> -{ -public: - explicit RefBatchNormalizationFloat32Workload(const BatchNormalizationQueueDescriptor& descriptor, - const WorkloadInfo& info); - virtual void Execute() const override; - -private: - std::unique_ptr<ScopedCpuTensorHandle> m_Mean; - std::unique_ptr<ScopedCpuTensorHandle> m_Variance; - std::unique_ptr<ScopedCpuTensorHandle> m_Beta; - std::unique_ptr<ScopedCpuTensorHandle> m_Gamma; -}; - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp b/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp deleted file mode 100644 index e248ad4b9d..0000000000 --- a/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp +++ /dev/null @@ -1,47 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefBatchNormalizationUint8Workload.hpp" - -#include "BatchNormImpl.hpp" -#include "RefWorkloadUtils.hpp" - -#include "Profiling.hpp" - -#include <vector> - -namespace armnn -{ -RefBatchNormalizationUint8Workload::RefBatchNormalizationUint8Workload( - const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) - : Uint8Workload<BatchNormalizationQueueDescriptor>(descriptor, info), - m_Mean(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Mean))), - m_Variance(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Variance))), - m_Beta(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Beta))), - m_Gamma(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Gamma))) {} - -void RefBatchNormalizationUint8Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchNormalizationUint8Workload_Execute"); - - const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorInfo& varInfo = GetTensorInfo(m_Variance.get()); - const TensorInfo& meanInfo = GetTensorInfo(m_Mean.get()); - const TensorInfo& gammaInfo = GetTensorInfo(m_Gamma.get()); - const TensorInfo& betaInfo = GetTensorInfo(m_Beta.get()); - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - - auto input = Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0); - auto var = Dequantize(m_Variance->GetConstTensor<uint8_t>(), varInfo); - auto mean = Dequantize(m_Mean->GetConstTensor<uint8_t>(), meanInfo); - auto gamma = Dequantize(m_Gamma->GetConstTensor<uint8_t>(), gammaInfo); - auto beta = Dequantize(m_Beta->GetConstTensor<uint8_t>(), betaInfo); - - std::vector<float> results(outputInfo.GetNumElements()); - BatchNormImpl(m_Data, var.data(), mean.data(), gamma.data(), beta.data(), results.data(), input.data()); - Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo); -} - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp b/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp new file mode 100644 index 0000000000..b43b104459 --- /dev/null +++ b/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp @@ -0,0 +1,45 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefBatchNormalizationWorkload.hpp" + +#include "BatchNormImpl.hpp" +#include "RefWorkloadUtils.hpp" + +#include "Profiling.hpp" + +namespace armnn +{ + +RefBatchNormalizationWorkload::RefBatchNormalizationWorkload(const BatchNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload(descriptor, info) + , m_Mean (std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Mean))) + , m_Variance(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Variance))) + , m_Beta (std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Beta))) + , m_Gamma (std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Gamma))) +{} + +void RefBatchNormalizationWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchNormalizationWorkload_Execute"); + + std::unique_ptr<Decoder<float>> meanDecoder = MakeDecoder<float>(GetTensorInfo(m_Mean.get()), + m_Mean.get()->Map(true)); + std::unique_ptr<Decoder<float>> varianceDecoder = MakeDecoder<float>(GetTensorInfo(m_Variance.get()), + m_Variance.get()->Map(true)); + std::unique_ptr<Decoder<float>> gammaDecoder = MakeDecoder<float>(GetTensorInfo(m_Gamma.get()), + m_Gamma.get()->Map(true)); + std::unique_ptr<Decoder<float>> betaDecoder = MakeDecoder<float>(GetTensorInfo(m_Beta.get()), + m_Beta.get()->Map(true)); + std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(m_Data.m_Inputs[0]), + m_Data.m_Inputs[0]->Map()); + std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(m_Data.m_Outputs[0]), + m_Data.m_Outputs[0]->Map()); + + BatchNormImpl(m_Data, *meanDecoder, *varianceDecoder, *betaDecoder, *gammaDecoder, *inputDecoder, *outputEncoder); +} + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.hpp b/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp index 7c288a5c07..9e71e7b4c5 100644 --- a/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.hpp +++ b/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp @@ -11,11 +11,11 @@ namespace armnn { -class RefBatchNormalizationUint8Workload : public Uint8Workload<BatchNormalizationQueueDescriptor> +class RefBatchNormalizationWorkload : public BaseWorkload<BatchNormalizationQueueDescriptor> { public: - explicit RefBatchNormalizationUint8Workload(const BatchNormalizationQueueDescriptor& descriptor, - const WorkloadInfo& info); + explicit RefBatchNormalizationWorkload(const BatchNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info); virtual void Execute() const override; private: diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index c8c26b0b83..7ccd4efc54 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -21,7 +21,7 @@ #include "RefGatherWorkload.hpp" #include "Softmax.hpp" #include "TensorBufferArrayView.hpp" -#include "RefBatchNormalizationFloat32Workload.hpp" +#include "RefBatchNormalizationWorkload.hpp" #include "Splitter.hpp" #include "RefDepthwiseConvolution2dWorkload.hpp" #include "FullyConnected.hpp" @@ -29,7 +29,6 @@ #include "RefFloorWorkload.hpp" #include "RefSoftmaxWorkload.hpp" #include "RefResizeBilinearFloat32Workload.hpp" -#include "RefBatchNormalizationUint8Workload.hpp" #include "ResizeBilinear.hpp" #include "RefNormalizationFloat32Workload.hpp" #include "RefDetectionPostProcessFloat32Workload.hpp" |