Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r--   src/backends/backendsCommon/LayerSupportBase.cpp                 8
-rw-r--r--   src/backends/backendsCommon/LayerSupportBase.hpp                  6
-rw-r--r--   src/backends/backendsCommon/WorkloadData.cpp                     46
-rw-r--r--   src/backends/backendsCommon/WorkloadData.hpp                     15
-rw-r--r--   src/backends/backendsCommon/WorkloadFactory.cpp                  22
-rw-r--r--   src/backends/backendsCommon/WorkloadFactory.hpp                   4
-rw-r--r--   src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp     2
7 files changed, 103 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index 656407d020..c41f0b11ea 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -234,6 +234,14 @@ bool LayerSupportBase::IsInputSupported(const TensorInfo& input,
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
+bool LayerSupportBase::IsInstanceNormalizationSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const InstanceNormalizationDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
bool LayerSupportBase::IsL2NormalizationSupported(const TensorInfo& input,
const TensorInfo& output,
const L2NormalizationDescriptor& descriptor,
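
Beyond this base-class stub, a concrete backend opts in by overriding the same hook. A minimal sketch, assuming a hypothetical MyBackendLayerSupport class; the class name and include path are illustrative, only the method signature comes from this change:

#include <backendsCommon/LayerSupportBase.hpp>  // path as used inside the ArmNN source tree (assumed)

#include <string>

class MyBackendLayerSupport : public armnn::LayerSupportBase
{
public:
    bool IsInstanceNormalizationSupported(const armnn::TensorInfo& input,
                                          const armnn::TensorInfo& output,
                                          const armnn::InstanceNormalizationDescriptor& /*descriptor*/,
                                          armnn::Optional<std::string&> reasonIfUnsupported) const override
    {
        // Mirror the constraints enforced by InstanceNormalizationQueueDescriptor::Validate:
        // Float32/Float16 only, rank <= 4, matching input/output data types.
        const bool isFloat = input.GetDataType() == armnn::DataType::Float32 ||
                             input.GetDataType() == armnn::DataType::Float16;
        if (!isFloat || input.GetNumDimensions() > 4 ||
            input.GetDataType() != output.GetDataType())
        {
            if (reasonIfUnsupported.has_value())
            {
                reasonIfUnsupported.value() = "InstanceNormalization: unsupported tensor configuration";
            }
            return false;
        }
        return true;
    }
};
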
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index c3875e6ced..495870e645 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -141,6 +141,12 @@ public:
bool IsInputSupported(const TensorInfo& input,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsInstanceNormalizationSupported(
+ const TensorInfo& input,
+ const TensorInfo& output,
+ const InstanceNormalizationDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsL2NormalizationSupported(const TensorInfo& input,
const TensorInfo& output,
const L2NormalizationDescriptor& descriptor,
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index e49fd09be0..aca5023f97 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1233,6 +1233,52 @@ void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo)
}
}
+void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
+
+ ValidateNumInputs(workloadInfo, descriptorName, 1);
+ ValidateNumOutputs(workloadInfo, descriptorName, 1);
+
+ const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+
+ if (inputTensorInfo.GetNumDimensions() > 4)
+ {
+ throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
+ }
+
+ ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+ // Check the supported data types
+ std::vector<DataType> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16
+ };
+
+ ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
+ ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+
+ ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+ ValidatePointer(m_Beta, descriptorName, "beta");
+ ValidatePointer(m_Eps, descriptorName, "epsilon");
+ ValidatePointer(m_Gamma, descriptorName, "gamma");
+
+ const TensorInfo& beta = m_Beta->GetTensorInfo();
+ const TensorInfo& epsilon = m_Eps->GetTensorInfo();
+ const TensorInfo& gamma = m_Gamma->GetTensorInfo();
+
+ ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
+ ValidateTensorNumDimensions(epsilon, descriptorName, 1, "epsilon");
+ ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
+
+ ValidateTensorDataTypesMatch(inputTensorInfo, beta, descriptorName, "input", "beta");
+ ValidateTensorDataTypesMatch(inputTensorInfo, epsilon, descriptorName, "input", "epsilon");
+ ValidateTensorDataTypesMatch(inputTensorInfo, gamma, descriptorName, "input", "gamma");
+}
+
void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
const std::string descriptorName{"L2NormalizationQueueDescriptor"};
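
For reference, the computation these checks guard is standard instance normalisation: every (batch, channel) slice is normalised over its spatial extent, then scaled by gamma and shifted by beta. A minimal standalone sketch over an NCHW Float32 buffer, with scalar gamma/beta/eps for brevity (the workload itself reads them from the handles validated above; all names here are illustrative):

#include <cmath>
#include <cstddef>
#include <vector>

void InstanceNormalizationReference(const std::vector<float>& in, std::vector<float>& out,
                                    std::size_t n, std::size_t c, std::size_t h, std::size_t w,
                                    float gamma, float beta, float eps)
{
    const std::size_t spatial = h * w;
    for (std::size_t ni = 0; ni < n; ++ni)
    {
        for (std::size_t ci = 0; ci < c; ++ci)
        {
            const std::size_t offset = (ni * c + ci) * spatial;

            // Mean and variance over the H*W elements of this (batch, channel) slice.
            float mean = 0.0f;
            for (std::size_t i = 0; i < spatial; ++i) { mean += in[offset + i]; }
            mean /= static_cast<float>(spatial);

            float variance = 0.0f;
            for (std::size_t i = 0; i < spatial; ++i)
            {
                const float diff = in[offset + i] - mean;
                variance += diff * diff;
            }
            variance /= static_cast<float>(spatial);

            // Normalise, then apply the learned scale (gamma) and shift (beta).
            const float invStdDev = 1.0f / std::sqrt(variance + eps);
            for (std::size_t i = 0; i < spatial; ++i)
            {
                out[offset + i] = gamma * (in[offset + i] - mean) * invStdDev + beta;
            }
        }
    }
}
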
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 177bfb7af3..14d7b588e1 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -307,6 +307,21 @@ struct FakeQuantizationQueueDescriptor : QueueDescriptorWithParameters<FakeQuant
void Validate(const WorkloadInfo& workloadInfo) const;
};
+struct InstanceNormalizationQueueDescriptor : QueueDescriptorWithParameters<InstanceNormalizationDescriptor>
+{
+ InstanceNormalizationQueueDescriptor()
+ : m_Beta(nullptr)
+ , m_Eps(nullptr)
+ , m_Gamma(nullptr)
+ {
+ }
+
+ const ConstCpuTensorHandle* m_Beta;
+ const ConstCpuTensorHandle* m_Eps;
+ const ConstCpuTensorHandle* m_Gamma;
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
struct L2NormalizationQueueDescriptor : QueueDescriptorWithParameters<L2NormalizationDescriptor>
{
void Validate(const WorkloadInfo& workloadInfo) const;
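
A sketch of how a caller might populate the new queue descriptor and run the validation added above before creating a workload. The tensor shapes, the use of ScopedCpuTensorHandle for the parameter tensors, and the include paths are assumptions for illustration:

#include <armnn/Tensor.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
#include <backendsCommon/WorkloadData.hpp>

void PrepareInstanceNormalizationDescriptor()
{
    using namespace armnn;

    TensorInfo dataInfo({ 1, 3, 16, 16 }, DataType::Float32);  // rank <= 4
    TensorInfo paramInfo({ 3 }, DataType::Float32);            // rank-1, same type as input

    ScopedCpuTensorHandle beta(paramInfo);
    ScopedCpuTensorHandle eps(paramInfo);
    ScopedCpuTensorHandle gamma(paramInfo);

    InstanceNormalizationQueueDescriptor data;
    data.m_Beta  = &beta;
    data.m_Eps   = &eps;
    data.m_Gamma = &gamma;

    WorkloadInfo info;
    info.m_InputTensorInfos.push_back(dataInfo);
    info.m_OutputTensorInfos.push_back(dataInfo);

    data.Validate(info);  // throws InvalidArgumentException if any check fails
}
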
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 44888b3ac9..98fe158fc5 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -371,6 +371,21 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
break;
}
+ case LayerType::InstanceNormalization:
+ {
+ auto cLayer = boost::polymorphic_downcast<const InstanceNormalizationLayer*>(&layer);
+ const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
+
+ const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+
+ result = layerSupportObject->IsInstanceNormalizationSupported(
+ OverrideDataType(input, dataType),
+ OverrideDataType(output, dataType),
+ descriptor,
+ reason);
+ break;
+ }
case LayerType::L2Normalization:
{
auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
@@ -1139,6 +1154,13 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDes
return std::unique_ptr<IWorkload>();
}
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
+ const InstanceNormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::unique_ptr<IWorkload>();
+}
+
std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
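
A backend that actually implements the layer would override this virtual and return a concrete workload in place of the null default. MyWorkloadFactory and MyInstanceNormalizationWorkload are hypothetical names used only to illustrate the shape of such an override:

#include <memory>

std::unique_ptr<armnn::IWorkload> MyWorkloadFactory::CreateInstanceNormalization(
    const armnn::InstanceNormalizationQueueDescriptor& descriptor,
    const armnn::WorkloadInfo& info) const
{
    // Class declarations omitted for brevity; the workload consumes the validated
    // descriptor and the input/output tensor infos carried in 'info'.
    return std::make_unique<MyInstanceNormalizationWorkload>(descriptor, info);
}
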
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index 2809e2f9e8..9fa0221f31 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -120,6 +120,10 @@ public:
virtual std::unique_ptr<IWorkload> CreateGreater(const GreaterQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
+ virtual std::unique_ptr<IWorkload> CreateInstanceNormalization(
+ const InstanceNormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
virtual std::unique_ptr<IWorkload> CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index e492cd6908..c8604140ec 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -435,6 +435,8 @@ DECLARE_LAYER_POLICY_1_PARAM(Greater)
DECLARE_LAYER_POLICY_CUSTOM_PARAM(Input, armnn::LayerBindingId)
+DECLARE_LAYER_POLICY_2_PARAM(InstanceNormalization)
+
DECLARE_LAYER_POLICY_2_PARAM(L2Normalization)
DECLARE_LAYER_POLICY_2_PARAM(Lstm)