diff options
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/test/CreateWorkload.hpp | 37 |
1 files changed, 36 insertions, 1 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index f484a21f48..aad6244c4b 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -1262,6 +1262,41 @@ std::unique_ptr<BatchToSpaceNdWorkload> CreateBatchToSpaceNdWorkloadTest(armnn:: return workload; } +template <typename LogSoftmaxWorkload, armnn::DataType DataType> +std::unique_ptr<LogSoftmaxWorkload> CreateLogSoftmaxWorkloadTest(armnn::IWorkloadFactory& factory, + armnn::Graph& graph) +{ + // Create the layer we're testing. + LogSoftmaxDescriptor logSoftmaxDescriptor; + // Set Axis to 1 if CL or Neon until further Axes are supported. + if (factory.GetBackendId() == armnn::Compute::CpuAcc || factory.GetBackendId() == armnn::Compute::GpuAcc) + { + logSoftmaxDescriptor.m_Axis = 0; + } + + Layer* const layer = graph.AddLayer<LogSoftmaxLayer>(logSoftmaxDescriptor, "layer"); + // Create extra layers. + Layer* const input = graph.AddLayer<InputLayer>(0, "input"); + Layer* const output = graph.AddLayer<OutputLayer>(0, "output"); + + // Connect up + armnn::TensorInfo tensorInfo({4, 1}, DataType); + + Connect(input, layer, tensorInfo); + Connect(layer, output, tensorInfo); + CreateTensorHandles(graph, factory); + + // Make the workload and checks it. + auto workload = MakeAndCheckWorkload<LogSoftmaxWorkload>(*layer, factory); + + LogSoftmaxQueueDescriptor queueDescriptor = workload->GetData(); + BOOST_TEST(queueDescriptor.m_Inputs.size() == 1); + BOOST_TEST(queueDescriptor.m_Outputs.size() == 1); + + // Return so we can do extra, backend-specific tests. + return workload; +} + template <typename L2NormalizationWorkload, armnn::DataType DataType> std::unique_ptr<L2NormalizationWorkload> CreateL2NormalizationWorkloadTest(armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW) |