aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin May <kevin.may@arm.com>2019-10-09 12:37:34 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-10-09 16:13:36 +0000
commit09ca49cdcfbe377da979a19df9bcdb7cbffc7b50 (patch)
tree62cf0881012c80498ded1963c82654efcb761bf2
parent0d4863dd3ef0d1cafea0857e70f70b22a841ed71 (diff)
downloadarmnn-09ca49cdcfbe377da979a19df9bcdb7cbffc7b50.tar.gz
IVGCVSW-3888 Add INSTANCE_NORMALIZATION Reference implementation
Signed-off-by: Kevin May <kevin.may@arm.com> Change-Id: I725022f86e990c482ea323fc90fd136fe493ed68
-rw-r--r--src/backends/reference/RefLayerSupport.cpp31
-rw-r--r--src/backends/reference/RefLayerSupport.hpp5
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp6
-rw-r--r--src/backends/reference/RefWorkloadFactory.hpp3
-rw-r--r--src/backends/reference/backend.mk2
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp13
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt4
-rw-r--r--src/backends/reference/workloads/InstanceNorm.cpp86
-rw-r--r--src/backends/reference/workloads/InstanceNorm.hpp20
-rw-r--r--src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp33
-rw-r--r--src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp22
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp1
12 files changed, 226 insertions, 0 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 06da77603d..0d6b16cdf8 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -833,6 +833,37 @@ bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
return true;
}
+bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const InstanceNormalizationDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ ignore_unused(descriptor);
+ // Define supported types
+ std::array<DataType, 4> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16
+ };
+
+ bool supported = true;
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference Instance Normalization: input type not supported.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference Instance Normalization: output type not supported.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference Instance Normalization: input and output types mismatched.");
+
+ supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
+ "Reference Instance Normalization: input and output shapes have different "
+ "num total elements.");
+
+ return supported;
+}
+
bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
const TensorInfo& output,
const L2NormalizationDescriptor& descriptor,
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index cc9478d871..36080f7da4 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -139,6 +139,11 @@ public:
bool IsInputSupported(const TensorInfo& input,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsInstanceNormalizationSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const InstanceNormalizationDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsL2NormalizationSupported(const TensorInfo& input,
const TensorInfo& output,
const L2NormalizationDescriptor& descriptor,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 254b221cc8..8c082749a4 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -481,4 +481,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSlice(const SliceQueueDescr
return std::make_unique<RefSliceWorkload>(descriptor, info);
}
+std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateInstanceNormalization(
+ const InstanceNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
+{
+ return std::make_unique<RefInstanceNormalizationWorkload>(descriptor, info);
+}
+
} // namespace armnn
diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp
index e8e11e027e..0a1fab127c 100644
--- a/src/backends/reference/RefWorkloadFactory.hpp
+++ b/src/backends/reference/RefWorkloadFactory.hpp
@@ -223,6 +223,9 @@ public:
std::unique_ptr<IWorkload> CreateSlice(const SliceQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateInstanceNormalization(const InstanceNormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
private:
template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 597fba8d7d..f45b01549a 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -34,6 +34,7 @@ BACKEND_SOURCES := \
workloads/ElementwiseFunction.cpp \
workloads/FullyConnected.cpp \
workloads/Gather.cpp \
+ workloads/InstanceNorm.cpp \
workloads/LstmUtils.cpp \
workloads/Mean.cpp \
workloads/Concatenate.cpp \
@@ -60,6 +61,7 @@ BACKEND_SOURCES := \
workloads/RefFloorWorkload.cpp \
workloads/RefFullyConnectedWorkload.cpp \
workloads/RefGatherWorkload.cpp \
+ workloads/RefInstanceNormalizationWorkload.cpp \
workloads/RefL2NormalizationWorkload.cpp \
workloads/RefLstmWorkload.cpp \
workloads/RefMeanWorkload.cpp \
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 0058e15a8e..cef3a800ac 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -345,6 +345,19 @@ ARMNN_AUTO_TEST_CASE(ConstantLinearActivation, ConstantLinearActivationTest)
ARMNN_AUTO_TEST_CASE(ConstantLinearActivationUint8, ConstantLinearActivationUint8Test)
ARMNN_AUTO_TEST_CASE(ConstantLinearActivationInt16, ConstantLinearActivationInt16Test)
+// InstanceNormalization
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat32Nchw, InstanceNormFloat32Test, DataLayout::NCHW);
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat16Nchw, InstanceNormFloat16Test, DataLayout::NCHW);
+
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat32Nhwc, InstanceNormFloat32Test, DataLayout::NHWC);
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat16Nhwc, InstanceNormFloat16Test, DataLayout::NHWC);
+
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat32Nchw2, InstanceNormFloat32Test2, DataLayout::NCHW);
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat16Nchw2, InstanceNormFloat16Test2, DataLayout::NCHW);
+
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat32Nhwc2, InstanceNormFloat32Test2, DataLayout::NHWC);
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat16Nhwc2, InstanceNormFloat16Test2, DataLayout::NHWC);
+
// Normalization
ARMNN_AUTO_TEST_CASE(SimpleNormalizationAcross, SimpleNormalizationAcrossTest)
ARMNN_AUTO_TEST_CASE(SimpleNormalizationWithin, SimpleNormalizationWithinTest)
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index c2eb025789..9a5f427d37 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -35,6 +35,8 @@ list(APPEND armnnRefBackendWorkloads_sources
FullyConnected.hpp
Gather.cpp
Gather.hpp
+ InstanceNorm.cpp
+ InstanceNorm.hpp
LstmUtils.hpp
LstmUtils.cpp
Maximum.hpp
@@ -89,6 +91,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefFullyConnectedWorkload.hpp
RefGatherWorkload.cpp
RefGatherWorkload.hpp
+ RefInstanceNormalizationWorkload.cpp
+ RefInstanceNormalizationWorkload.hpp
RefL2NormalizationWorkload.cpp
RefL2NormalizationWorkload.hpp
RefLstmWorkload.cpp
diff --git a/src/backends/reference/workloads/InstanceNorm.cpp b/src/backends/reference/workloads/InstanceNorm.cpp
new file mode 100644
index 0000000000..9d6532fa6e
--- /dev/null
+++ b/src/backends/reference/workloads/InstanceNorm.cpp
@@ -0,0 +1,86 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "InstanceNorm.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include <armnn/Tensor.hpp>
+
+#include <DataLayoutIndexed.hpp>
+
+#include <cmath>
+
+namespace armnn
+{
+
+void InstanceNorm(const InstanceNormalizationQueueDescriptor& data,
+ Decoder<float>& inputDecoder,
+ Encoder<float>& outputEncoder)
+{
+ const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]);
+ const TensorShape inputShape = inputInfo.GetShape();
+
+ armnnUtils::DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout);
+
+ unsigned int inputBatches = inputShape[0];
+ unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()];
+ unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()];
+ unsigned int inputChannels = inputShape[dataLayout.GetChannelsIndex()];
+
+ float beta = data.m_Parameters.m_Beta;
+ float eps = data.m_Parameters.m_Eps;
+ float gamma = data.m_Parameters.m_Gamma;
+
+ for (unsigned int n = 0; n < inputBatches; ++n)
+ {
+ for (unsigned int c = 0; c < inputChannels; ++c)
+ {
+ float mean = 0, var = 0;
+
+ //Calculate Mean
+ for (unsigned int h = 0; h < inputHeight; h++)
+ {
+ for (unsigned int w = 0; w < inputWidth; w++)
+ {
+ unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);
+
+ inputDecoder[index];
+ float value = inputDecoder.Get();
+ mean += value;
+ }
+ }
+ mean /= static_cast<float>(inputHeight * inputWidth);
+
+ //Calculate Variance
+ for (unsigned int h = 0; h < inputHeight; h++)
+ {
+ for (unsigned int w = 0; w < inputWidth; w++)
+ {
+ unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);
+
+ inputDecoder[index];
+ float value = inputDecoder.Get();
+ var += (value - mean) * (value - mean);
+ }
+ }
+ var /= static_cast<float>(inputHeight * inputWidth);
+
+ // Apply Instance Normalisation
+ for (unsigned int h = 0; h < inputHeight; ++h)
+ {
+ for (unsigned int w = 0; w < inputWidth; ++w)
+ {
+ unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);
+ inputDecoder[index];
+ outputEncoder[index];
+ outputEncoder.Set((inputDecoder.Get() - mean) * gamma / std::sqrt ( var + eps) + beta);
+ }
+
+ }
+ }
+ }
+}
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/InstanceNorm.hpp b/src/backends/reference/workloads/InstanceNorm.hpp
new file mode 100644
index 0000000000..d73b4cd115
--- /dev/null
+++ b/src/backends/reference/workloads/InstanceNorm.hpp
@@ -0,0 +1,20 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "Encoders.hpp"
+#include "Decoders.hpp"
+
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+void InstanceNorm(const InstanceNormalizationQueueDescriptor& data,
+ Decoder<float>& inputData,
+ Encoder<float>& outputData);
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp
new file mode 100644
index 0000000000..875d11a00d
--- /dev/null
+++ b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp
@@ -0,0 +1,33 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefInstanceNormalizationWorkload.hpp"
+
+#include "InstanceNorm.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include "Profiling.hpp"
+
+namespace armnn
+{
+
+RefInstanceNormalizationWorkload::RefInstanceNormalizationWorkload(
+ const InstanceNormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info)
+ : BaseWorkload<InstanceNormalizationQueueDescriptor>(descriptor, info) {}
+
+void RefInstanceNormalizationWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefInstanceNormalizationWorkload_Execute");
+
+ std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(m_Data.m_Inputs[0]),
+ m_Data.m_Inputs[0]->Map());
+ std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(m_Data.m_Outputs[0]),
+ m_Data.m_Outputs[0]->Map());
+
+ InstanceNorm(m_Data, *inputDecoder, *outputEncoder);
+}
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp
new file mode 100644
index 0000000000..3d8a72c361
--- /dev/null
+++ b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp
@@ -0,0 +1,22 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class RefInstanceNormalizationWorkload : public BaseWorkload<InstanceNormalizationQueueDescriptor>
+{
+public:
+ explicit RefInstanceNormalizationWorkload(const InstanceNormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info);
+ virtual void Execute() const override;
+};
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 94592cb53e..39dfa0517b 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -36,6 +36,7 @@
#include "RefFloorWorkload.hpp"
#include "RefFakeQuantizationFloat32Workload.hpp"
#include "RefGatherWorkload.hpp"
+#include "RefInstanceNormalizationWorkload.hpp"
#include "RefL2NormalizationWorkload.hpp"
#include "RefLstmWorkload.hpp"
#include "RefMeanWorkload.hpp"