diff options
Diffstat (limited to 'src/armnn/test')
-rw-r--r-- | src/armnn/test/CreateWorkload.hpp | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index 135a4421cd..87df00af3c 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -892,6 +892,34 @@ std::unique_ptr<ResizeBilinearWorkload> CreateResizeBilinearWorkloadTest(armnn:: return workload; } +template <typename RsqrtWorkload, armnn::DataType DataType> +std::unique_ptr<RsqrtWorkload> CreateRsqrtWorkloadTest(armnn::IWorkloadFactory& factory, + armnn::Graph& graph) +{ + Layer* const layer = graph.AddLayer<RsqrtLayer>("rsqrt"); + + // Creates extra layers. + Layer* const input = graph.AddLayer<InputLayer>(0, "input"); + Layer* const output = graph.AddLayer<OutputLayer>(0, "output"); + + // Connects up. + armnn::TensorInfo tensorInfo({1, 1}, DataType); + + Connect(input, layer, tensorInfo); + Connect(layer, output, tensorInfo); + + CreateTensorHandles(graph, factory); + + // Makes the workload and checks it. + auto workload = MakeAndCheckWorkload<RsqrtWorkload>(*layer, graph, factory); + + RsqrtQueueDescriptor queueDescriptor = workload->GetData(); + BOOST_TEST(queueDescriptor.m_Inputs.size() == 1); + BOOST_TEST(queueDescriptor.m_Outputs.size() == 1); + + return workload; +} + template <typename L2NormalizationWorkload, armnn::DataType DataType> std::unique_ptr<L2NormalizationWorkload> CreateL2NormalizationWorkloadTest(armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW) |