From 69e653f9b2a7c8a2ab0cd3556b246a9df21b81d6 Mon Sep 17 00:00:00 2001 From: Keith Davis Date: Thu, 2 Jul 2020 11:49:26 +0100 Subject: IVGCVSW-3897 Add NEON LOG_SOFTMAX Workload Signed-off-by: Keith Davis Change-Id: I632b5ac7f188853de68e232e81568b3fca238d42 --- src/armnn/test/CreateWorkload.hpp | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) (limited to 'src/armnn') diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index f484a21f48..aad6244c4b 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -1262,6 +1262,41 @@ std::unique_ptr CreateBatchToSpaceNdWorkloadTest(armnn:: return workload; } +template +std::unique_ptr CreateLogSoftmaxWorkloadTest(armnn::IWorkloadFactory& factory, + armnn::Graph& graph) +{ + // Create the layer we're testing. + LogSoftmaxDescriptor logSoftmaxDescriptor; + // Set Axis to 1 if CL or Neon until further Axes are supported. + if (factory.GetBackendId() == armnn::Compute::CpuAcc || factory.GetBackendId() == armnn::Compute::GpuAcc) + { + logSoftmaxDescriptor.m_Axis = 0; + } + + Layer* const layer = graph.AddLayer(logSoftmaxDescriptor, "layer"); + // Create extra layers. + Layer* const input = graph.AddLayer(0, "input"); + Layer* const output = graph.AddLayer(0, "output"); + + // Connect up + armnn::TensorInfo tensorInfo({4, 1}, DataType); + + Connect(input, layer, tensorInfo); + Connect(layer, output, tensorInfo); + CreateTensorHandles(graph, factory); + + // Make the workload and checks it. + auto workload = MakeAndCheckWorkload(*layer, factory); + + LogSoftmaxQueueDescriptor queueDescriptor = workload->GetData(); + BOOST_TEST(queueDescriptor.m_Inputs.size() == 1); + BOOST_TEST(queueDescriptor.m_Outputs.size() == 1); + + // Return so we can do extra, backend-specific tests. + return workload; +} + template std::unique_ptr CreateL2NormalizationWorkloadTest(armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW) -- cgit v1.2.1