From 4951d84b1174a4bb0a5d9c900740f64201f765bf Mon Sep 17 00:00:00 2001 From: narpra01 Date: Fri, 18 Jan 2019 16:53:53 +0000 Subject: IVGCVSW-2510 Ref workload implementation for Gather operator * add implemenentation for GatherQueueDescriptor validate function * add FirstInputTypedWorkload to allow type check on the first input tensor only * add ref workload implemenentation for float and uint8 * add Gather layer support in Ref * unit tests Change-Id: I4578a3211f11d24aa29d15bcf7f45b0445bcd1ee --- .../reference/workloads/RefGatherWorkload.cpp | 37 ++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 src/backends/reference/workloads/RefGatherWorkload.cpp (limited to 'src/backends/reference/workloads/RefGatherWorkload.cpp') diff --git a/src/backends/reference/workloads/RefGatherWorkload.cpp b/src/backends/reference/workloads/RefGatherWorkload.cpp new file mode 100644 index 0000000000..49b37cb1ac --- /dev/null +++ b/src/backends/reference/workloads/RefGatherWorkload.cpp @@ -0,0 +1,37 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefGatherWorkload.hpp" + +#include "Gather.hpp" +#include "Profiling.hpp" +#include "RefWorkloadUtils.hpp" +#include "TypeUtils.hpp" + +namespace armnn +{ + +template +void RefGatherWorkload::Execute() const +{ + using T = ResolveType; + + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefGatherWorkload_Execute"); + + const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + + const T* paramsData = GetInputTensorData(0, m_Data); + const int32_t* indicesData = GetInputTensorData(1, m_Data); + T* outputData = GetOutputTensorData(0, m_Data); + + Gather(inputInfo0, inputInfo1, outputInfo, paramsData, indicesData, outputData); +} + +template class RefGatherWorkload; +template class RefGatherWorkload; + +} //namespace armnn -- cgit v1.2.1