diff options
Diffstat (limited to 'src/backends/reference/workloads')
6 files changed, 166 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index c2eb025789..9a5f427d37 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -35,6 +35,8 @@ list(APPEND armnnRefBackendWorkloads_sources FullyConnected.hpp Gather.cpp Gather.hpp + InstanceNorm.cpp + InstanceNorm.hpp LstmUtils.hpp LstmUtils.cpp Maximum.hpp @@ -89,6 +91,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefFullyConnectedWorkload.hpp RefGatherWorkload.cpp RefGatherWorkload.hpp + RefInstanceNormalizationWorkload.cpp + RefInstanceNormalizationWorkload.hpp RefL2NormalizationWorkload.cpp RefL2NormalizationWorkload.hpp RefLstmWorkload.cpp diff --git a/src/backends/reference/workloads/InstanceNorm.cpp b/src/backends/reference/workloads/InstanceNorm.cpp new file mode 100644 index 0000000000..9d6532fa6e --- /dev/null +++ b/src/backends/reference/workloads/InstanceNorm.cpp @@ -0,0 +1,86 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "InstanceNorm.hpp" +#include "RefWorkloadUtils.hpp" + +#include <armnn/Tensor.hpp> + +#include <DataLayoutIndexed.hpp> + +#include <cmath> + +namespace armnn +{ + +void InstanceNorm(const InstanceNormalizationQueueDescriptor& data, + Decoder<float>& inputDecoder, + Encoder<float>& outputEncoder) +{ + const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]); + const TensorShape inputShape = inputInfo.GetShape(); + + armnnUtils::DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout); + + unsigned int inputBatches = inputShape[0]; + unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()]; + unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()]; + unsigned int inputChannels = inputShape[dataLayout.GetChannelsIndex()]; + + float beta = data.m_Parameters.m_Beta; + float eps = data.m_Parameters.m_Eps; + float gamma = data.m_Parameters.m_Gamma; + + for (unsigned int n = 0; n < inputBatches; ++n) + { + for (unsigned int c = 0; c < inputChannels; ++c) + { + float mean = 0, var = 0; + + //Calculate Mean + for (unsigned int h = 0; h < inputHeight; h++) + { + for (unsigned int w = 0; w < inputWidth; w++) + { + unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w); + + inputDecoder[index]; + float value = inputDecoder.Get(); + mean += value; + } + } + mean /= static_cast<float>(inputHeight * inputWidth); + + //Calculate Variance + for (unsigned int h = 0; h < inputHeight; h++) + { + for (unsigned int w = 0; w < inputWidth; w++) + { + unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w); + + inputDecoder[index]; + float value = inputDecoder.Get(); + var += (value - mean) * (value - mean); + } + } + var /= static_cast<float>(inputHeight * inputWidth); + + // Apply Instance Normalisation + for (unsigned int h = 0; h < inputHeight; ++h) + { + for (unsigned int w = 0; w < inputWidth; ++w) + { + unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w); + inputDecoder[index]; + outputEncoder[index]; + outputEncoder.Set((inputDecoder.Get() - mean) * gamma / std::sqrt ( var + eps) + beta); + } + + } + } + } +} + +} // namespace armnn diff --git a/src/backends/reference/workloads/InstanceNorm.hpp b/src/backends/reference/workloads/InstanceNorm.hpp new file mode 100644 index 0000000000..d73b4cd115 --- /dev/null +++ b/src/backends/reference/workloads/InstanceNorm.hpp @@ -0,0 +1,20 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "Encoders.hpp" +#include "Decoders.hpp" + +#include <backendsCommon/WorkloadData.hpp> + +namespace armnn +{ + +void InstanceNorm(const InstanceNormalizationQueueDescriptor& data, + Decoder<float>& inputData, + Encoder<float>& outputData); + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp new file mode 100644 index 0000000000..875d11a00d --- /dev/null +++ b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp @@ -0,0 +1,33 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefInstanceNormalizationWorkload.hpp" + +#include "InstanceNorm.hpp" +#include "RefWorkloadUtils.hpp" + +#include "Profiling.hpp" + +namespace armnn +{ + +RefInstanceNormalizationWorkload::RefInstanceNormalizationWorkload( + const InstanceNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload<InstanceNormalizationQueueDescriptor>(descriptor, info) {} + +void RefInstanceNormalizationWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefInstanceNormalizationWorkload_Execute"); + + std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(m_Data.m_Inputs[0]), + m_Data.m_Inputs[0]->Map()); + std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(m_Data.m_Outputs[0]), + m_Data.m_Outputs[0]->Map()); + + InstanceNorm(m_Data, *inputDecoder, *outputEncoder); +} + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp new file mode 100644 index 0000000000..3d8a72c361 --- /dev/null +++ b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp @@ -0,0 +1,22 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <backendsCommon/Workload.hpp> +#include <backendsCommon/WorkloadData.hpp> + +namespace armnn +{ + +class RefInstanceNormalizationWorkload : public BaseWorkload<InstanceNormalizationQueueDescriptor> +{ +public: + explicit RefInstanceNormalizationWorkload(const InstanceNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info); + virtual void Execute() const override; +}; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 94592cb53e..39dfa0517b 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -36,6 +36,7 @@ #include "RefFloorWorkload.hpp" #include "RefFakeQuantizationFloat32Workload.hpp" #include "RefGatherWorkload.hpp" +#include "RefInstanceNormalizationWorkload.hpp" #include "RefL2NormalizationWorkload.hpp" #include "RefLstmWorkload.hpp" #include "RefMeanWorkload.hpp" |