authorMatteo Martincigh <matteo.martincigh@arm.com>2019-06-13 17:27:46 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-06-19 11:12:35 +0000
commitab9e52563f624d9782b97400f643d2632cc8d770 (patch)
tree5fec7f9e398dbd2337241cd3cc908f7b0e1d588b
parentbee4bc944aa50782ff22cb4a31fbc611212a5e89 (diff)
downloadarmnn-ab9e52563f624d9782b97400f643d2632cc8d770.tar.gz
IVGCVSW-3268 Add Reference workload support for the new Prelu Activation layer
* Added reference workload for the PReLU Activation layer
* Added factory methods
* Added validation support
* Added Int16 support
* Added unit tests

Change-Id: Ic950d908c5e0a335dccd2960a3ffab0f8b599876
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
-rw-r--r--src/armnn/test/CreateWorkload.hpp33
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp3
-rw-r--r--src/backends/backendsCommon/test/LayerTests.hpp92
-rw-r--r--src/backends/reference/RefLayerSupport.cpp32
-rw-r--r--src/backends/reference/RefLayerSupport.hpp5
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp10
-rw-r--r--src/backends/reference/RefWorkloadFactory.hpp3
-rw-r--r--src/backends/reference/backend.mk2
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp28
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp5
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt4
-rw-r--r--src/backends/reference/workloads/PreluImpl.cpp35
-rw-r--r--src/backends/reference/workloads/PreluImpl.hpp21
-rw-r--r--src/backends/reference/workloads/RefPreluWorkload.cpp35
-rw-r--r--src/backends/reference/workloads/RefPreluWorkload.hpp22
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp1
16 files changed, 330 insertions, 1 deletion
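For context, the PReLU (Parametric ReLU) activation implemented by this patch computes f(x) = x for x >= 0 and f(x) = alpha * x for x < 0, where alpha is a tensor broadcast against the input (typically one slope per channel). A minimal standalone sketch of that element-wise rule, independent of the Arm NN classes added below:

#include <cstddef>
#include <iostream>
#include <vector>

// Standalone illustration of the PReLU rule implemented by this patch:
//   f(x) = x          for x >= 0
//   f(x) = alpha * x  for x <  0
// alpha is broadcast per channel over an NHWC-style buffer.
int main()
{
    const std::size_t channels = 3;
    const std::vector<float> input = { 0.0f, 1.0f, -1.0f, -2.0f, 4.0f, -0.5f };
    const std::vector<float> alpha = { 0.0f, 1.0f, 2.0f }; // one slope per channel

    std::vector<float> output(input.size());
    for (std::size_t i = 0; i < input.size(); ++i)
    {
        const float a = alpha[i % channels]; // per-channel broadcast of alpha
        output[i] = input[i] < 0.0f ? a * input[i] : input[i];
    }

    for (float v : output) { std::cout << v << " "; } // 0 1 -2 -0 4 -1
    std::cout << std::endl;
    return 0;
}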
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index 8863fecce3..b075744434 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -1280,4 +1280,37 @@ std::unique_ptr<ConstantWorkload> CreateConstantWorkloadTest(armnn::IWorkloadFac
return workloadConstant;
}
+template <typename PreluWorkload, armnn::DataType DataType>
+std::unique_ptr<PreluWorkload> CreatePreluWorkloadTest(armnn::IWorkloadFactory& factory,
+ armnn::Graph& graph,
+ const armnn::TensorShape& outputShape)
+{
+ // Creates the PReLU layer
+ Layer* const layer = graph.AddLayer<PreluLayer>("prelu");
+
+ // Creates extra layers
+ Layer* const input = graph.AddLayer<InputLayer> (0, "input");
+ Layer* const alpha = graph.AddLayer<InputLayer> (1, "alpha");
+ Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
+
+ // Connects up
+ armnn::TensorInfo inputTensorInfo ({ 1, 4, 1, 2 }, DataType);
+ armnn::TensorInfo alphaTensorInfo ({ 5, 4, 3, 1 }, DataType);
+ armnn::TensorInfo outputTensorInfo(outputShape, DataType);
+ Connect(input, layer, inputTensorInfo, 0, 0);
+ Connect(alpha, layer, alphaTensorInfo, 0, 1);
+ Connect(layer, output, outputTensorInfo, 0, 0);
+ CreateTensorHandles(graph, factory);
+
+ // Makes the workload and checks it
+ auto workload = MakeAndCheckWorkload<PreluWorkload>(*layer, graph, factory);
+
+ PreluQueueDescriptor queueDescriptor = workload->GetData();
+ BOOST_TEST(queueDescriptor.m_Inputs.size() == 2);
+ BOOST_TEST(queueDescriptor.m_Outputs.size() == 1);
+
+ // Returns so we can do extra, backend-specific tests.
+ return workload;
}
+
+} // Anonymous namespace
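A note on the shapes used in CreatePreluWorkloadTest above: the input { 1, 4, 1, 2 } and alpha { 5, 4, 3, 1 } rely on implicit broadcasting, where each output dimension is the larger of the two whenever one of them is 1; that is why the reference tests further down pass { 5, 4, 3, 2 } as outputShape. A small sketch of that per-dimension rule (standard NumPy-style broadcasting, assumed here to be what the layer's shape inference and the ShapesAreBroadcastCompatible rule expect):

#include <algorithm>
#include <array>
#include <cstddef>
#include <iostream>
#include <stdexcept>

// Per-dimension broadcast rule assumed by the test shapes: two dimensions are
// compatible if they are equal or one of them is 1; the result takes the larger.
std::array<unsigned int, 4> BroadcastShape(const std::array<unsigned int, 4>& a,
                                           const std::array<unsigned int, 4>& b)
{
    std::array<unsigned int, 4> out{};
    for (std::size_t i = 0; i < 4; ++i)
    {
        if (a[i] != b[i] && a[i] != 1 && b[i] != 1)
        {
            throw std::runtime_error("shapes are not broadcast compatible");
        }
        out[i] = std::max(a[i], b[i]);
    }
    return out;
}

int main()
{
    // Input { 1, 4, 1, 2 } and alpha { 5, 4, 3, 1 } broadcast to { 5, 4, 3, 2 },
    // matching the outputShape passed by the Ref workload creation tests.
    const auto out = BroadcastShape({ 1, 4, 1, 2 }, { 5, 4, 3, 1 });
    for (auto d : out) { std::cout << d << " "; } // 5 4 3 2
    std::cout << std::endl;
    return 0;
}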
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index d8c10bdea6..b7317af9cd 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1720,7 +1720,8 @@ void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
DataType::Float16,
DataType::Float32,
- DataType::QuantisedAsymm8
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
};
ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
diff --git a/src/backends/backendsCommon/test/LayerTests.hpp b/src/backends/backendsCommon/test/LayerTests.hpp
index 058d6946f6..bf0d063091 100644
--- a/src/backends/backendsCommon/test/LayerTests.hpp
+++ b/src/backends/backendsCommon/test/LayerTests.hpp
@@ -3158,3 +3158,95 @@ LayerTestResult<T, 3> ConcatDifferentInputOutputQParamTest(
return ret;
}
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> PreluTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ armnn::TensorInfo inputTensorInfo ({ 1, 2, 2, 3 }, ArmnnType);
+ armnn::TensorInfo alphaTensorInfo ({ 1, 1, 1, 3 }, ArmnnType);
+ armnn::TensorInfo outputTensorInfo({ 1, 2, 2, 3 }, ArmnnType);
+
+ if (armnn::IsQuantizedType<T>())
+ {
+ inputTensorInfo.SetQuantizationScale(0.25f);
+ inputTensorInfo.SetQuantizationOffset(128);
+ alphaTensorInfo.SetQuantizationScale(0.25f);
+ alphaTensorInfo.SetQuantizationOffset(50);
+ outputTensorInfo.SetQuantizationScale(0.5f);
+ outputTensorInfo.SetQuantizationOffset(120);
+ }
+
+ std::vector<float> inputData
+ {
+ // Expected quantized values:
+ // 128, 128, 128, 132, 132, 132, 124, 124, 124, 120, 120, 120
+ 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -2.0f, -2.0f, -2.0f
+ };
+ std::vector<float> alphaData
+ {
+ // Expected quantized values:
+ // 50, 54, 58
+ 0.0f, 1.0f, 2.0f
+ };
+ std::vector<float> outputExpectedData =
+ {
+ // Expected quantized values:
+ // 120, 120, 120, 122, 122, 122, 120, 118, 116, 120, 116, 112
+ 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, -1.0f, -2.0f, 0.0f, -2.0f, -4.0f
+ };
+
+ auto input = MakeTensor<T, 4>(inputTensorInfo, QuantizedVector<T>(inputTensorInfo.GetQuantizationScale(),
+ inputTensorInfo.GetQuantizationOffset(),
+ inputData));
+ auto alpha = MakeTensor<T, 4>(alphaTensorInfo, QuantizedVector<T>(alphaTensorInfo.GetQuantizationScale(),
+ alphaTensorInfo.GetQuantizationOffset(),
+ alphaData));
+
+ LayerTestResult<T, 4> result(outputTensorInfo);
+ result.outputExpected = MakeTensor<T, 4>(outputTensorInfo,
+ QuantizedVector<T>(outputTensorInfo.GetQuantizationScale(),
+ outputTensorInfo.GetQuantizationOffset(),
+ outputExpectedData));
+
+ std::unique_ptr <armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
+ std::unique_ptr <armnn::ITensorHandle> alphaHandle = workloadFactory.CreateTensorHandle(alphaTensorInfo);
+ std::unique_ptr <armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
+
+ armnn::PreluQueueDescriptor descriptor;
+ armnn::WorkloadInfo info;
+ AddInputToWorkload (descriptor, info, inputTensorInfo, inputHandle.get());
+ AddInputToWorkload (descriptor, info, alphaTensorInfo, alphaHandle.get());
+ AddOutputToWorkload(descriptor, info, outputTensorInfo, outputHandle.get());
+
+ std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreatePrelu(descriptor, info);
+
+ inputHandle->Allocate();
+ alphaHandle->Allocate();
+ outputHandle->Allocate();
+
+ CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
+ CopyDataToITensorHandle(alphaHandle.get(), &alpha[0][0][0][0]);
+
+ workload->Execute();
+
+ CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
+
+ return result;
+}
+
+template LayerTestResult<typename armnn::ResolveType<armnn::DataType::Float32>, 4>
+PreluTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+template LayerTestResult<typename armnn::ResolveType<armnn::DataType::QuantisedAsymm8>, 4>
+PreluTest<armnn::DataType::QuantisedAsymm8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+template LayerTestResult<typename armnn::ResolveType<armnn::DataType::QuantisedSymm16>, 4>
+PreluTest<armnn::DataType::QuantisedSymm16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
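The "Expected quantized values" comments in PreluTest above follow the usual affine quantization rule q = round(x / scale) + offset; for example, with the output scale 0.5 and offset 120, the float result -4.0f maps to round(-4.0 / 0.5) + 120 = 112. A short sketch of that arithmetic (illustrative only; the test itself relies on Arm NN's QuantizedVector helper):

#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// Affine quantization as used in the expected-value comments of PreluTest:
//   q = round(x / scale) + offset
int32_t Quantize(float x, float scale, int32_t offset)
{
    return static_cast<int32_t>(std::round(x / scale)) + offset;
}

int main()
{
    const float   outputScale  = 0.5f;
    const int32_t outputOffset = 120;

    const std::vector<float> outputExpected = { 0.0f,  0.0f,  0.0f, 1.0f,  1.0f,  1.0f,
                                                0.0f, -1.0f, -2.0f, 0.0f, -2.0f, -4.0f };
    for (float x : outputExpected)
    {
        std::cout << Quantize(x, outputScale, outputOffset) << " ";
    }
    std::cout << std::endl; // 120 120 120 122 122 122 120 118 116 120 116 112
    return 0;
}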
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 919dd5fd6c..077aa1ce3a 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1353,4 +1353,36 @@ bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
return supported;
}
+bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
+ const TensorInfo& alpha,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ bool supported = true;
+
+ std::array<DataType, 3> supportedTypes
+ {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "PReLU: input is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
+ "PReLU: alpha is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "PReLU: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
+ "PReLU: input, alpha and output types are mismatched");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
+ "PReLU: shapes are not suitable for implicit broadcast");
+
+ return supported;
+}
+
} // namespace armnn
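The new IsPreluSupported query can also be called directly, for example to check CpuRef support before building a network. A hedged usage sketch against the signature added above (the include paths are assumptions and depend on how the armnn source tree is consumed):

// Illustrative only: RefLayerSupport.hpp lives under src/backends/reference
// in the armnn source tree; adjust include paths to your build setup.
#include <armnn/Optional.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
#include <reference/RefLayerSupport.hpp>

#include <iostream>
#include <string>

int main()
{
    armnn::TensorInfo input ({ 1, 2, 2, 3 }, armnn::DataType::Float32);
    armnn::TensorInfo alpha ({ 1, 1, 1, 3 }, armnn::DataType::Float32);
    armnn::TensorInfo output({ 1, 2, 2, 3 }, armnn::DataType::Float32);

    armnn::RefLayerSupport layerSupport;
    std::string reason;

    // reason is filled in by the CheckSupportRule helpers when a check fails.
    const bool supported =
        layerSupport.IsPreluSupported(input, alpha, output, armnn::Optional<std::string&>(reason));

    std::cout << (supported ? std::string("PReLU is supported on CpuRef")
                            : "PReLU is not supported: " + reason) << std::endl;
    return 0;
}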
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 8850c6e105..041701d8e1 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -255,6 +255,11 @@ public:
const TensorInfo& input1,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
+ bool IsPreluSupported(const TensorInfo& input,
+ const TensorInfo& alpha,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
};
} // namespace armnn
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 4467bd4ad6..29b2c52254 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -436,4 +436,14 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDequantize(const Dequantize
return std::make_unique<RefDequantizeWorkload>(descriptor, info);
}
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePrelu(const PreluQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ if (IsFloat16(info))
+ {
+ return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ }
+ return std::make_unique<RefPreluWorkload>(descriptor, info);
+}
+
} // namespace armnn
diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp
index 78f6bab92c..333a9ca257 100644
--- a/src/backends/reference/RefWorkloadFactory.hpp
+++ b/src/backends/reference/RefWorkloadFactory.hpp
@@ -187,6 +187,9 @@ public:
std::unique_ptr<IWorkload> CreateQuantize(const QuantizeQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreatePrelu(const PreluQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
private:
template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index ecd281208a..a430f4fb68 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -25,6 +25,7 @@ BACKEND_SOURCES := \
workloads/Concatenate.cpp \
workloads/Pad.cpp \
workloads/Pooling2d.cpp \
+ workloads/PreluImpl.cpp \
workloads/RefActivationWorkload.cpp \
workloads/RefBatchNormalizationWorkload.cpp \
workloads/RefBatchToSpaceNdFloat32Workload.cpp \
@@ -50,6 +51,7 @@ BACKEND_SOURCES := \
workloads/RefPadWorkload.cpp \
workloads/RefPermuteWorkload.cpp \
workloads/RefPooling2dWorkload.cpp \
+ workloads/RefPreluWorkload.cpp \
workloads/RefQuantizeWorkload.cpp \
workloads/RefReshapeWorkload.cpp \
workloads/RefResizeBilinearWorkload.cpp \
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index e541692654..14615f89df 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -870,4 +870,32 @@ BOOST_AUTO_TEST_CASE(CreateConstantSigned32Workload)
RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::Signed32>({ 2, 3, 2, 10 });
}
+template <typename armnn::DataType DataType>
+static void RefCreatePreluWorkloadTest(const armnn::TensorShape& outputShape)
+{
+ armnn::Graph graph;
+ RefWorkloadFactory factory;
+ auto workload = CreatePreluWorkloadTest<RefPreluWorkload, DataType>(factory, graph, outputShape);
+
+ // Check output is as expected
+ auto queueDescriptor = workload->GetData();
+ auto outputHandle = boost::polymorphic_downcast<CpuTensorHandle*>(queueDescriptor.m_Outputs[0]);
+ BOOST_TEST((outputHandle->GetTensorInfo() == TensorInfo(outputShape, DataType)));
+}
+
+BOOST_AUTO_TEST_CASE(CreatePreluFloat32Workload)
+{
+ RefCreatePreluWorkloadTest<armnn::DataType::Float32>({ 5, 4, 3, 2 });
+}
+
+BOOST_AUTO_TEST_CASE(CreatePreluUint8Workload)
+{
+ RefCreatePreluWorkloadTest<armnn::DataType::QuantisedAsymm8>({ 5, 4, 3, 2 });
+}
+
+BOOST_AUTO_TEST_CASE(CreatePreluInt16Workload)
+{
+ RefCreatePreluWorkloadTest<armnn::DataType::QuantisedSymm16>({ 5, 4, 3, 2 });
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 95e93653bc..b540d185d3 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -859,4 +859,9 @@ ARMNN_AUTO_TEST_CASE(QuantizeSimpleUint8, QuantizeSimpleUint8Test)
ARMNN_AUTO_TEST_CASE(QuantizeClampUint8, QuantizeClampUint8Test)
ARMNN_AUTO_TEST_CASE(QuantizeClampInt16, QuantizeClampInt16Test)
+// PReLU
+ARMNN_AUTO_TEST_CASE(PreluFloat32, PreluTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(PreluUint8, PreluTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(PreluInt16, PreluTest<armnn::DataType::QuantisedSymm16>)
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 1ab38ccbcb..db0daa0310 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -36,6 +36,8 @@ list(APPEND armnnRefBackendWorkloads_sources
Pad.hpp
Pooling2d.cpp
Pooling2d.hpp
+ PreluImpl.cpp
+ PreluImpl.hpp
RefActivationWorkload.cpp
RefActivationWorkload.hpp
RefBatchNormalizationWorkload.cpp
@@ -84,6 +86,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefPermuteWorkload.hpp
RefPooling2dWorkload.cpp
RefPooling2dWorkload.hpp
+ RefPreluWorkload.cpp
+ RefPreluWorkload.hpp
RefQuantizeWorkload.cpp
RefQuantizeWorkload.hpp
RefReshapeWorkload.cpp
diff --git a/src/backends/reference/workloads/PreluImpl.cpp b/src/backends/reference/workloads/PreluImpl.cpp
new file mode 100644
index 0000000000..458025bb0a
--- /dev/null
+++ b/src/backends/reference/workloads/PreluImpl.cpp
@@ -0,0 +1,35 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "PreluImpl.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Broadcast.hpp"
+
+namespace armnn
+{
+
+void PreluImpl(const PreluQueueDescriptor& data,
+ Decoder<float>& inputData,
+ Decoder<float>& alphaData,
+ Encoder<float>& outputData)
+{
+ const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]);
+ const TensorInfo& alphaInfo = GetTensorInfo(data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]);
+
+ const TensorShape& inputShape = inputInfo.GetShape();
+ const TensorShape& alphaShape = alphaInfo.GetShape();
+ const TensorShape& outputShape = outputInfo.GetShape();
+
+ // PReLU activation: f(x) = alpha * x for x < 0, f(x) = x for x >= 0
+ auto prelu = [](float x, float alpha)
+ {
+ return x < 0 ? alpha * x : x;
+ };
+
+ BroadcastLoop(inputShape, alphaShape, outputShape).Unroll(prelu, 0, inputData, alphaData, outputData);
+}
+
+} // namespace armnn
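PreluImpl works entirely in float: the Decoder<float>/Encoder<float> pair (wired up in RefPreluWorkload below) dequantizes each element on read and requantizes on write, which is how one loop serves Float32, QuantisedAsymm8 and QuantisedSymm16 alike. A simplified standalone sketch of that round trip for a single asymmetric-uint8 element, using the scales and offsets from PreluTest (not the actual armnn Decoder/Encoder machinery):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Dequantize -> apply PReLU in float -> requantize: the shape of the path that
// RefPreluWorkload follows through Decoder<float>/Encoder<float>.
float Dequantize(uint8_t q, float scale, int32_t offset)
{
    return scale * static_cast<float>(static_cast<int32_t>(q) - offset);
}

uint8_t Quantize(float x, float scale, int32_t offset)
{
    const int32_t q = static_cast<int32_t>(std::round(x / scale)) + offset;
    return static_cast<uint8_t>(std::min(255, std::max(0, q))); // clamp to the uint8 range
}

int main()
{
    // Scales/offsets taken from PreluTest: input 0.25/128, alpha 0.25/50, output 0.5/120.
    const uint8_t qInput = 120; // dequantizes to -2.0f
    const uint8_t qAlpha = 58;  // dequantizes to  2.0f

    const float x = Dequantize(qInput, 0.25f, 128);
    const float a = Dequantize(qAlpha, 0.25f, 50);
    const float y = (x < 0.0f) ? a * x : x; // PReLU applied in float

    std::cout << static_cast<int>(Quantize(y, 0.5f, 120)) << std::endl; // prints 112
    return 0;
}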
diff --git a/src/backends/reference/workloads/PreluImpl.hpp b/src/backends/reference/workloads/PreluImpl.hpp
new file mode 100644
index 0000000000..9299b1c7f7
--- /dev/null
+++ b/src/backends/reference/workloads/PreluImpl.hpp
@@ -0,0 +1,21 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "Encoders.hpp"
+#include "Decoders.hpp"
+
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+void PreluImpl(const PreluQueueDescriptor& data,
+ Decoder<float>& inputData,
+ Decoder<float>& alphaData,
+ Encoder<float>& outputData);
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefPreluWorkload.cpp b/src/backends/reference/workloads/RefPreluWorkload.cpp
new file mode 100644
index 0000000000..cdc0a63711
--- /dev/null
+++ b/src/backends/reference/workloads/RefPreluWorkload.cpp
@@ -0,0 +1,35 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefPreluWorkload.hpp"
+
+#include "RefWorkloadUtils.hpp"
+#include "PreluImpl.hpp"
+
+#include <Profiling.hpp>
+
+namespace armnn
+{
+
+RefPreluWorkload::RefPreluWorkload(const PreluQueueDescriptor& descriptor,
+ const WorkloadInfo& info)
+ : BaseWorkload(descriptor, info)
+{}
+
+void RefPreluWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefPreluWorkload_Execute");
+
+ std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(m_Data.m_Inputs[0]),
+ m_Data.m_Inputs[0]->Map());
+ std::unique_ptr<Decoder<float>> alphaDecoder = MakeDecoder<float>(GetTensorInfo(m_Data.m_Inputs[1]),
+ m_Data.m_Inputs[1]->Map());
+ std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(m_Data.m_Outputs[0]),
+ m_Data.m_Outputs[0]->Map());
+
+ PreluImpl(m_Data, *inputDecoder, *alphaDecoder, *outputEncoder);
+}
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefPreluWorkload.hpp b/src/backends/reference/workloads/RefPreluWorkload.hpp
new file mode 100644
index 0000000000..72839e67dc
--- /dev/null
+++ b/src/backends/reference/workloads/RefPreluWorkload.hpp
@@ -0,0 +1,22 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class RefPreluWorkload : public BaseWorkload<PreluQueueDescriptor>
+{
+public:
+ explicit RefPreluWorkload(const PreluQueueDescriptor& descriptor,
+ const WorkloadInfo& info);
+ virtual void Execute() const override;
+};
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index b14129146a..41b16fa56f 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -51,3 +51,4 @@
#include "RefDequantizeWorkload.hpp"
#include "RefQuantizeWorkload.hpp"
#include "RefReshapeWorkload.hpp"
+#include "RefPreluWorkload.hpp"