From ce5045a00485f8a8c35814c0781ccbcca5678e5c Mon Sep 17 00:00:00 2001
From: Kevin May
Date: Wed, 2 Oct 2019 14:07:47 +0100
Subject: IVGCVSW-3932 Add frontend for INSTANCE_NORMALIZATION

Signed-off-by: Kevin May
Change-Id: Ib152148ccd8d2733c617d0cf9402661fc6b71316
---
 src/backends/backendsCommon/LayerSupportBase.cpp |  8 ++++
 src/backends/backendsCommon/LayerSupportBase.hpp |  6 +++
 src/backends/backendsCommon/WorkloadData.cpp     | 46 ++++++++++++++++++++++
 src/backends/backendsCommon/WorkloadData.hpp     | 15 +++++++
 src/backends/backendsCommon/WorkloadFactory.cpp  | 22 +++++++++++
 src/backends/backendsCommon/WorkloadFactory.hpp  |  4 ++
 .../test/IsLayerSupportedTestImpl.hpp            |  2 +
 7 files changed, 103 insertions(+)

(limited to 'src/backends')

diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index 656407d020..c41f0b11ea 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -234,6 +234,14 @@ bool LayerSupportBase::IsInputSupported(const TensorInfo& input,
     return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
 }
 
+bool LayerSupportBase::IsInstanceNormalizationSupported(const TensorInfo& input,
+                                                        const TensorInfo& output,
+                                                        const InstanceNormalizationDescriptor& descriptor,
+                                                        Optional<std::string&> reasonIfUnsupported) const
+{
+    return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
 bool LayerSupportBase::IsL2NormalizationSupported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   const L2NormalizationDescriptor& descriptor,
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index c3875e6ced..495870e645 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -141,6 +141,12 @@ public:
     bool IsInputSupported(const TensorInfo& input,
                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
+    bool IsInstanceNormalizationSupported(
+        const TensorInfo& input,
+        const TensorInfo& output,
+        const InstanceNormalizationDescriptor& descriptor,
+        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
     bool IsL2NormalizationSupported(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const L2NormalizationDescriptor& descriptor,
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index e49fd09be0..aca5023f97 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1233,6 +1233,52 @@ void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo)
     }
 }
 
+void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+    const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
+
+    ValidateNumInputs(workloadInfo, descriptorName, 1);
+    ValidateNumOutputs(workloadInfo, descriptorName, 1);
+
+    const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
+    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+
+    if (inputTensorInfo.GetNumDimensions() > 4)
+    {
+        throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
+    }
+
+    ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+    // Check the supported data types
+    std::vector<DataType> supportedTypes =
+    {
+        DataType::Float32,
+        DataType::Float16
+    };
+
+    ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
+    ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+
+    ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+    ValidatePointer(m_Beta, descriptorName, "beta");
+    ValidatePointer(m_Eps, descriptorName, "epsilon");
+    ValidatePointer(m_Gamma, descriptorName, "gamma");
+
+    const TensorInfo& beta    = m_Beta->GetTensorInfo();
+    const TensorInfo& epsilon = m_Eps->GetTensorInfo();
+    const TensorInfo& gamma   = m_Gamma->GetTensorInfo();
+
+    ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
+    ValidateTensorNumDimensions(epsilon, descriptorName, 1, "epsilon");
+    ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
+
+    ValidateTensorDataTypesMatch(inputTensorInfo, beta, descriptorName, "input", "beta");
+    ValidateTensorDataTypesMatch(inputTensorInfo, epsilon, descriptorName, "input", "epsilon");
+    ValidateTensorDataTypesMatch(inputTensorInfo, gamma, descriptorName, "input", "gamma");
+}
+
 void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
     const std::string descriptorName{"L2NormalizationQueueDescriptor"};
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 177bfb7af3..14d7b588e1 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -307,6 +307,21 @@ struct FakeQuantizationQueueDescriptor : QueueDescriptorWithParameters<FakeQuan
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
+struct InstanceNormalizationQueueDescriptor : QueueDescriptorWithParameters<InstanceNormalizationDescriptor>
+{
+    InstanceNormalizationQueueDescriptor()
+        : m_Beta(nullptr)
+        , m_Eps(nullptr)
+        , m_Gamma(nullptr)
+    {
+    }
+
+    const ConstCpuTensorHandle* m_Beta;
+    const ConstCpuTensorHandle* m_Eps;
+    const ConstCpuTensorHandle* m_Gamma;
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
 struct L2NormalizationQueueDescriptor : QueueDescriptorWithParameters<L2NormalizationDescriptor>
 {
     void Validate(const WorkloadInfo& workloadInfo) const;
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 44888b3ac9..98fe158fc5 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -371,6 +371,21 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
             result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
             break;
         }
+        case LayerType::InstanceNormalization:
+        {
+            auto cLayer = boost::polymorphic_downcast<const InstanceNormalizationLayer*>(&layer);
+            const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
+
+            const TensorInfo& input  = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+
+            result = layerSupportObject->IsInstanceNormalizationSupported(
+                OverrideDataType(input, dataType),
+                OverrideDataType(output, dataType),
+                descriptor,
+                reason);
+            break;
+        }
         case LayerType::L2Normalization:
         {
             auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
@@ -1139,6 +1154,13 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDes
     return std::unique_ptr<IWorkload>();
 }
 
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
+    const InstanceNormalizationQueueDescriptor& descriptor,
+    const WorkloadInfo& info) const
+{
+    return std::unique_ptr<IWorkload>();
+}
+
 std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
                                                                    const WorkloadInfo& info) const
 {
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index 2809e2f9e8..9fa0221f31 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -120,6 +120,10 @@ public:
     virtual std::unique_ptr<IWorkload> CreateGreater(const GreaterQueueDescriptor& descriptor,
                                                      const WorkloadInfo& info) const;
 
+    virtual std::unique_ptr<IWorkload> CreateInstanceNormalization(
+        const InstanceNormalizationQueueDescriptor& descriptor,
+        const WorkloadInfo& info) const;
+
     virtual std::unique_ptr<IWorkload> CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
                                                              const WorkloadInfo& info) const;
 
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index e492cd6908..c8604140ec 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -435,6 +435,8 @@ DECLARE_LAYER_POLICY_1_PARAM(Greater)
 
 DECLARE_LAYER_POLICY_CUSTOM_PARAM(Input, armnn::LayerBindingId)
 
+DECLARE_LAYER_POLICY_2_PARAM(InstanceNormalization)
+
 DECLARE_LAYER_POLICY_2_PARAM(L2Normalization)
 
 DECLARE_LAYER_POLICY_2_PARAM(Lstm)
-- 
cgit v1.2.1
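
Note: the snippet below is not part of this patch. It is a minimal sketch of how a backend could opt in to the new layer once this frontend lands, by overriding the IsInstanceNormalizationSupported() virtual that the patch adds to LayerSupportBase. The class name SampleLayerSupport and the Float32-only policy are illustrative assumptions, not ArmNN code.

// Illustrative sketch only: "SampleLayerSupport" is a hypothetical backend class.
// It overrides the virtual introduced in LayerSupportBase.hpp above; a real backend
// would pair this with a CreateInstanceNormalization() override in its workload factory.
#include <backendsCommon/LayerSupportBase.hpp>

#include <armnn/Descriptors.hpp>
#include <armnn/Optional.hpp>
#include <armnn/Tensor.hpp>

#include <string>

class SampleLayerSupport : public armnn::LayerSupportBase
{
public:
    bool IsInstanceNormalizationSupported(
        const armnn::TensorInfo& input,
        const armnn::TensorInfo& output,
        const armnn::InstanceNormalizationDescriptor& /*descriptor*/,
        armnn::Optional<std::string&> reasonIfUnsupported = armnn::EmptyOptional()) const override
    {
        // Example policy: accept Float32 tensors only. A real backend would also
        // validate shapes, data layout and the descriptor parameters.
        const bool supported = input.GetDataType()  == armnn::DataType::Float32 &&
                               output.GetDataType() == armnn::DataType::Float32;
        if (!supported && reasonIfUnsupported)
        {
            reasonIfUnsupported.value() = "Sample backend: only Float32 InstanceNormalization is supported";
        }
        return supported;
    }
};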