aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.cpp8
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.hpp5
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp40
-rw-r--r--src/backends/backendsCommon/WorkloadData.hpp5
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp17
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.hpp3
-rw-r--r--src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp2
7 files changed, 80 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index 48705c89b2..12e4ee81ae 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -348,6 +348,14 @@ bool LayerSupportBase::IsPreCompiledSupported(const TensorInfo& input,
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
+bool LayerSupportBase::IsPreluSupported(const TensorInfo& input,
+ const TensorInfo& alpha,
+ const TensorInfo& output,
+ Optional<std::string &> reasonIfUnsupported) const
+{
+ return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
bool LayerSupportBase::IsQuantizeSupported(const armnn::TensorInfo& input,
const armnn::TensorInfo& output,
armnn::Optional<std::string&> reasonIfUnsupported) const
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index 4921cf9b3e..d035dfcd62 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -221,6 +221,11 @@ public:
const PreCompiledDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsPreluSupported(const TensorInfo& input,
+ const TensorInfo& alpha,
+ const TensorInfo& output,
+ Optional<std::string &> reasonIfUnsupported) const override;
+
bool IsQuantizeSupported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 7c9d4ac58c..d8c10bdea6 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1711,4 +1711,44 @@ void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
// This is internally generated so it should not need validation.
}
+void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ ValidateNumInputs(workloadInfo, "PreluQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "PreluQueueDescriptor", 1);
+
+ std::vector<DataType> supportedTypes
+ {
+ DataType::Float16,
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "PreluQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "PreluQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "PreluQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ { workloadInfo.m_InputTensorInfos[1].GetDataType() },
+ "PreluQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ { workloadInfo.m_OutputTensorInfos[0].GetDataType() },
+ "PreluQueueDescriptor");
+
+ ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
+ workloadInfo.m_InputTensorInfos[1],
+ workloadInfo.m_OutputTensorInfos[0],
+ "PreluQueueDescriptor",
+ "input",
+ "alpha");
+}
+
} //namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 501fdd8464..6a51bc3144 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -440,4 +440,9 @@ struct SwitchQueueDescriptor : QueueDescriptor
void Validate(const WorkloadInfo& workloadInfo) const;
};
+struct PreluQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
} //namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 678d330508..cca39198e1 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -785,6 +785,17 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
reason);
break;
}
+ case LayerType::Prelu:
+ {
+ const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
+ OverrideDataType(alpha, dataType),
+ OverrideDataType(output, dataType),
+ reason);
+ break;
+ }
default:
{
BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
@@ -1015,6 +1026,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiled
return std::unique_ptr<IWorkload>();
}
+std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &descriptor,
+ const WorkloadInfo &info) const
+{
+ return std::unique_ptr<IWorkload>();
+}
+
std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
const WorkloadInfo& Info) const
{
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index cc99356b73..c9fbe71f96 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -155,6 +155,9 @@ public:
virtual std::unique_ptr<IWorkload> CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
+ virtual std::unique_ptr<IWorkload> CreatePrelu(const PreluQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
virtual std::unique_ptr<IWorkload> CreateQuantize(const QuantizeQueueDescriptor& descriptor,
const WorkloadInfo& Info) const;
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index ff632fc701..111cf8f3e3 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -384,6 +384,8 @@ DECLARE_LAYER_POLICY_2_PARAM(Pooling2d)
DECLARE_LAYER_POLICY_2_PARAM(PreCompiled)
+DECLARE_LAYER_POLICY_1_PARAM(Prelu)
+
DECLARE_LAYER_POLICY_1_PARAM(Division)
DECLARE_LAYER_POLICY_2_PARAM(ResizeBilinear)