From 09ca49cdcfbe377da979a19df9bcdb7cbffc7b50 Mon Sep 17 00:00:00 2001 From: Kevin May Date: Wed, 9 Oct 2019 12:37:34 +0100 Subject: IVGCVSW-3888 Add INSTANCE_NORMALIZATION Reference implementation Signed-off-by: Kevin May Change-Id: I725022f86e990c482ea323fc90fd136fe493ed68 --- src/backends/reference/RefLayerSupport.cpp | 31 ++++++++ src/backends/reference/RefLayerSupport.hpp | 5 ++ src/backends/reference/RefWorkloadFactory.cpp | 6 ++ src/backends/reference/RefWorkloadFactory.hpp | 3 + src/backends/reference/backend.mk | 2 + src/backends/reference/test/RefLayerTests.cpp | 13 ++++ src/backends/reference/workloads/CMakeLists.txt | 4 + src/backends/reference/workloads/InstanceNorm.cpp | 86 ++++++++++++++++++++++ src/backends/reference/workloads/InstanceNorm.hpp | 20 +++++ .../workloads/RefInstanceNormalizationWorkload.cpp | 33 +++++++++ .../workloads/RefInstanceNormalizationWorkload.hpp | 22 ++++++ src/backends/reference/workloads/RefWorkloads.hpp | 1 + 12 files changed, 226 insertions(+) create mode 100644 src/backends/reference/workloads/InstanceNorm.cpp create mode 100644 src/backends/reference/workloads/InstanceNorm.hpp create mode 100644 src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp create mode 100644 src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp (limited to 'src') diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 06da77603d..0d6b16cdf8 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -833,6 +833,37 @@ bool RefLayerSupport::IsInputSupported(const TensorInfo& input, return true; } +bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input, + const TensorInfo& output, + const InstanceNormalizationDescriptor& descriptor, + Optional reasonIfUnsupported) const +{ + ignore_unused(descriptor); + // Define supported types + std::array supportedTypes = + { + DataType::Float32, + DataType::Float16 + }; + + bool supported = true; + + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference Instance Normalization: input type not supported."); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference Instance Normalization: output type not supported."); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference Instance Normalization: input and output types mismatched."); + + supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported, + "Reference Instance Normalization: input and output shapes have different " + "num total elements."); + + return supported; +} + bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input, const TensorInfo& output, const L2NormalizationDescriptor& descriptor, diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index cc9478d871..36080f7da4 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -139,6 +139,11 @@ public: bool IsInputSupported(const TensorInfo& input, Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsInstanceNormalizationSupported(const TensorInfo& input, + const TensorInfo& output, + const InstanceNormalizationDescriptor& descriptor, + Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsL2NormalizationSupported(const TensorInfo& input, const TensorInfo& output, const L2NormalizationDescriptor& descriptor, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 254b221cc8..8c082749a4 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -481,4 +481,10 @@ std::unique_ptr RefWorkloadFactory::CreateSlice(const SliceQueueDescr return std::make_unique(descriptor, info); } +std::unique_ptr RefWorkloadFactory::CreateInstanceNormalization( + const InstanceNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const +{ + return std::make_unique(descriptor, info); +} + } // namespace armnn diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp index e8e11e027e..0a1fab127c 100644 --- a/src/backends/reference/RefWorkloadFactory.hpp +++ b/src/backends/reference/RefWorkloadFactory.hpp @@ -223,6 +223,9 @@ public: std::unique_ptr CreateSlice(const SliceQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr CreateInstanceNormalization(const InstanceNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + private: template diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 597fba8d7d..f45b01549a 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -34,6 +34,7 @@ BACKEND_SOURCES := \ workloads/ElementwiseFunction.cpp \ workloads/FullyConnected.cpp \ workloads/Gather.cpp \ + workloads/InstanceNorm.cpp \ workloads/LstmUtils.cpp \ workloads/Mean.cpp \ workloads/Concatenate.cpp \ @@ -60,6 +61,7 @@ BACKEND_SOURCES := \ workloads/RefFloorWorkload.cpp \ workloads/RefFullyConnectedWorkload.cpp \ workloads/RefGatherWorkload.cpp \ + workloads/RefInstanceNormalizationWorkload.cpp \ workloads/RefL2NormalizationWorkload.cpp \ workloads/RefLstmWorkload.cpp \ workloads/RefMeanWorkload.cpp \ diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 0058e15a8e..cef3a800ac 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -345,6 +345,19 @@ ARMNN_AUTO_TEST_CASE(ConstantLinearActivation, ConstantLinearActivationTest) ARMNN_AUTO_TEST_CASE(ConstantLinearActivationUint8, ConstantLinearActivationUint8Test) ARMNN_AUTO_TEST_CASE(ConstantLinearActivationInt16, ConstantLinearActivationInt16Test) +// InstanceNormalization +ARMNN_AUTO_TEST_CASE(InstanceNormFloat32Nchw, InstanceNormFloat32Test, DataLayout::NCHW); +ARMNN_AUTO_TEST_CASE(InstanceNormFloat16Nchw, InstanceNormFloat16Test, DataLayout::NCHW); + +ARMNN_AUTO_TEST_CASE(InstanceNormFloat32Nhwc, InstanceNormFloat32Test, DataLayout::NHWC); +ARMNN_AUTO_TEST_CASE(InstanceNormFloat16Nhwc, InstanceNormFloat16Test, DataLayout::NHWC); + +ARMNN_AUTO_TEST_CASE(InstanceNormFloat32Nchw2, InstanceNormFloat32Test2, DataLayout::NCHW); +ARMNN_AUTO_TEST_CASE(InstanceNormFloat16Nchw2, InstanceNormFloat16Test2, DataLayout::NCHW); + +ARMNN_AUTO_TEST_CASE(InstanceNormFloat32Nhwc2, InstanceNormFloat32Test2, DataLayout::NHWC); +ARMNN_AUTO_TEST_CASE(InstanceNormFloat16Nhwc2, InstanceNormFloat16Test2, DataLayout::NHWC); + // Normalization ARMNN_AUTO_TEST_CASE(SimpleNormalizationAcross, SimpleNormalizationAcrossTest) ARMNN_AUTO_TEST_CASE(SimpleNormalizationWithin, SimpleNormalizationWithinTest) diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index c2eb025789..9a5f427d37 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -35,6 +35,8 @@ list(APPEND armnnRefBackendWorkloads_sources FullyConnected.hpp Gather.cpp Gather.hpp + InstanceNorm.cpp + InstanceNorm.hpp LstmUtils.hpp LstmUtils.cpp Maximum.hpp @@ -89,6 +91,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefFullyConnectedWorkload.hpp RefGatherWorkload.cpp RefGatherWorkload.hpp + RefInstanceNormalizationWorkload.cpp + RefInstanceNormalizationWorkload.hpp RefL2NormalizationWorkload.cpp RefL2NormalizationWorkload.hpp RefLstmWorkload.cpp diff --git a/src/backends/reference/workloads/InstanceNorm.cpp b/src/backends/reference/workloads/InstanceNorm.cpp new file mode 100644 index 0000000000..9d6532fa6e --- /dev/null +++ b/src/backends/reference/workloads/InstanceNorm.cpp @@ -0,0 +1,86 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "InstanceNorm.hpp" +#include "RefWorkloadUtils.hpp" + +#include + +#include + +#include + +namespace armnn +{ + +void InstanceNorm(const InstanceNormalizationQueueDescriptor& data, + Decoder& inputDecoder, + Encoder& outputEncoder) +{ + const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]); + const TensorShape inputShape = inputInfo.GetShape(); + + armnnUtils::DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout); + + unsigned int inputBatches = inputShape[0]; + unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()]; + unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()]; + unsigned int inputChannels = inputShape[dataLayout.GetChannelsIndex()]; + + float beta = data.m_Parameters.m_Beta; + float eps = data.m_Parameters.m_Eps; + float gamma = data.m_Parameters.m_Gamma; + + for (unsigned int n = 0; n < inputBatches; ++n) + { + for (unsigned int c = 0; c < inputChannels; ++c) + { + float mean = 0, var = 0; + + //Calculate Mean + for (unsigned int h = 0; h < inputHeight; h++) + { + for (unsigned int w = 0; w < inputWidth; w++) + { + unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w); + + inputDecoder[index]; + float value = inputDecoder.Get(); + mean += value; + } + } + mean /= static_cast(inputHeight * inputWidth); + + //Calculate Variance + for (unsigned int h = 0; h < inputHeight; h++) + { + for (unsigned int w = 0; w < inputWidth; w++) + { + unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w); + + inputDecoder[index]; + float value = inputDecoder.Get(); + var += (value - mean) * (value - mean); + } + } + var /= static_cast(inputHeight * inputWidth); + + // Apply Instance Normalisation + for (unsigned int h = 0; h < inputHeight; ++h) + { + for (unsigned int w = 0; w < inputWidth; ++w) + { + unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w); + inputDecoder[index]; + outputEncoder[index]; + outputEncoder.Set((inputDecoder.Get() - mean) * gamma / std::sqrt ( var + eps) + beta); + } + + } + } + } +} + +} // namespace armnn diff --git a/src/backends/reference/workloads/InstanceNorm.hpp b/src/backends/reference/workloads/InstanceNorm.hpp new file mode 100644 index 0000000000..d73b4cd115 --- /dev/null +++ b/src/backends/reference/workloads/InstanceNorm.hpp @@ -0,0 +1,20 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "Encoders.hpp" +#include "Decoders.hpp" + +#include + +namespace armnn +{ + +void InstanceNorm(const InstanceNormalizationQueueDescriptor& data, + Decoder& inputData, + Encoder& outputData); + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp new file mode 100644 index 0000000000..875d11a00d --- /dev/null +++ b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp @@ -0,0 +1,33 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefInstanceNormalizationWorkload.hpp" + +#include "InstanceNorm.hpp" +#include "RefWorkloadUtils.hpp" + +#include "Profiling.hpp" + +namespace armnn +{ + +RefInstanceNormalizationWorkload::RefInstanceNormalizationWorkload( + const InstanceNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload(descriptor, info) {} + +void RefInstanceNormalizationWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefInstanceNormalizationWorkload_Execute"); + + std::unique_ptr> inputDecoder = MakeDecoder(GetTensorInfo(m_Data.m_Inputs[0]), + m_Data.m_Inputs[0]->Map()); + std::unique_ptr> outputEncoder = MakeEncoder(GetTensorInfo(m_Data.m_Outputs[0]), + m_Data.m_Outputs[0]->Map()); + + InstanceNorm(m_Data, *inputDecoder, *outputEncoder); +} + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp new file mode 100644 index 0000000000..3d8a72c361 --- /dev/null +++ b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp @@ -0,0 +1,22 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include + +namespace armnn +{ + +class RefInstanceNormalizationWorkload : public BaseWorkload +{ +public: + explicit RefInstanceNormalizationWorkload(const InstanceNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info); + virtual void Execute() const override; +}; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 94592cb53e..39dfa0517b 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -36,6 +36,7 @@ #include "RefFloorWorkload.hpp" #include "RefFakeQuantizationFloat32Workload.hpp" #include "RefGatherWorkload.hpp" +#include "RefInstanceNormalizationWorkload.hpp" #include "RefL2NormalizationWorkload.hpp" #include "RefLstmWorkload.hpp" #include "RefMeanWorkload.hpp" -- cgit v1.2.1