aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/test
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/test')
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp42
1 files changed, 35 insertions, 7 deletions
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 14615f89df..d174093fd4 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -870,32 +870,60 @@ BOOST_AUTO_TEST_CASE(CreateConstantSigned32Workload)
RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::Signed32>({ 2, 3, 2, 10 });
}
-template <typename armnn::DataType DataType>
-static void RefCreatePreluWorkloadTest(const armnn::TensorShape& outputShape)
+static void RefCreatePreluWorkloadTest(const armnn::TensorShape& inputShape,
+ const armnn::TensorShape& alphaShape,
+ const armnn::TensorShape& outputShape,
+ armnn::DataType dataType)
{
armnn::Graph graph;
RefWorkloadFactory factory;
- auto workload = CreatePreluWorkloadTest<RefPreluWorkload, DataType>(factory, graph, outputShape);
+ auto workload = CreatePreluWorkloadTest<RefPreluWorkload>(factory,
+ graph,
+ inputShape,
+ alphaShape,
+ outputShape,
+ dataType);
// Check output is as expected
auto queueDescriptor = workload->GetData();
auto outputHandle = boost::polymorphic_downcast<CpuTensorHandle*>(queueDescriptor.m_Outputs[0]);
- BOOST_TEST((outputHandle->GetTensorInfo() == TensorInfo(outputShape, DataType)));
+ BOOST_TEST((outputHandle->GetTensorInfo() == TensorInfo(outputShape, dataType)));
}
BOOST_AUTO_TEST_CASE(CreatePreluFloat32Workload)
{
- RefCreatePreluWorkloadTest<armnn::DataType::Float32>({ 5, 4, 3, 2 });
+ RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::Float32);
}
BOOST_AUTO_TEST_CASE(CreatePreluUint8Workload)
{
- RefCreatePreluWorkloadTest<armnn::DataType::QuantisedAsymm8>({ 5, 4, 3, 2 });
+ RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::QuantisedAsymm8);
}
BOOST_AUTO_TEST_CASE(CreatePreluInt16Workload)
{
- RefCreatePreluWorkloadTest<armnn::DataType::QuantisedSymm16>({ 5, 4, 3, 2 });
+ RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::QuantisedSymm16);
+}
+
+BOOST_AUTO_TEST_CASE(CreatePreluFloat32NoBroadcastWorkload)
+{
+ BOOST_CHECK_THROW(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
+ armnn::DataType::Float32),
+ armnn::InvalidArgumentException);
+}
+
+BOOST_AUTO_TEST_CASE(CreatePreluUint8NoBroadcastWorkload)
+{
+ BOOST_CHECK_THROW(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
+ armnn::DataType::QuantisedAsymm8),
+ armnn::InvalidArgumentException);
+}
+
+BOOST_AUTO_TEST_CASE(CreatePreluInt16NoBroadcastWorkload)
+{
+ BOOST_CHECK_THROW(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
+ armnn::DataType::QuantisedSymm16),
+ armnn::InvalidArgumentException);
}
BOOST_AUTO_TEST_SUITE_END()