aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
authornarpra01 <narumol.prangnawarat@arm.com>2019-01-18 16:53:53 +0000
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-01-22 17:46:51 +0000
commit4951d84b1174a4bb0a5d9c900740f64201f765bf (patch)
treefe713740ac0acbaa8d74bbc9cdb450d08ef9f575 /src/backends/reference/workloads
parent0edd46737065d3e5c09aa959807e81f9836ee709 (diff)
downloadarmnn-4951d84b1174a4bb0a5d9c900740f64201f765bf.tar.gz
IVGCVSW-2510 Ref workload implementation for Gather operator
* add implemenentation for GatherQueueDescriptor validate function * add FirstInputTypedWorkload to allow type check on the first input tensor only * add ref workload implemenentation for float and uint8 * add Gather layer support in Ref * unit tests Change-Id: I4578a3211f11d24aa29d15bcf7f45b0445bcd1ee
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt4
-rw-r--r--src/backends/reference/workloads/Gather.cpp64
-rw-r--r--src/backends/reference/workloads/Gather.hpp21
-rw-r--r--src/backends/reference/workloads/RefGatherWorkload.cpp37
-rw-r--r--src/backends/reference/workloads/RefGatherWorkload.hpp36
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp2
6 files changed, 164 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index d15f77d6e4..583c89a5b4 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -19,6 +19,8 @@ list(APPEND armnnRefBackendWorkloads_sources
ElementwiseFunction.hpp
FullyConnected.cpp
FullyConnected.hpp
+ Gather.cpp
+ Gather.hpp
Maximum.hpp
Merger.hpp
Minimum.hpp
@@ -68,6 +70,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefFullyConnectedFloat32Workload.hpp
RefFullyConnectedUint8Workload.cpp
RefFullyConnectedUint8Workload.hpp
+ RefGatherWorkload.cpp
+ RefGatherWorkload.hpp
RefL2NormalizationFloat32Workload.cpp
RefL2NormalizationFloat32Workload.hpp
RefLstmFloat32Workload.cpp
diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp
new file mode 100644
index 0000000000..b195003e04
--- /dev/null
+++ b/src/backends/reference/workloads/Gather.cpp
@@ -0,0 +1,64 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "Gather.hpp"
+
+#include "RefWorkloadUtils.hpp"
+
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+template <typename T>
+void Gather(const TensorInfo& paramsInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& outputInfo,
+ const T* params,
+ const int32_t* indices,
+ T* output)
+{
+ const TensorShape& paramsShape = paramsInfo.GetShape();
+
+ unsigned int paramsProduct = 1;
+ for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i)
+ {
+ paramsProduct = paramsProduct * paramsShape[i];
+ }
+
+ unsigned int outIndex = 0;
+ for (unsigned int i = 0; i < indicesInfo.GetNumElements(); ++i)
+ {
+ unsigned int indx = boost::numeric_cast<unsigned int>(indices[i]);
+
+ BOOST_ASSERT(indices[i] >= 0 && indx < paramsShape[0]);
+
+ unsigned int startOffset = indx * paramsProduct;
+ unsigned int endOffset = startOffset + paramsProduct;
+ for (unsigned int j = startOffset; j < endOffset; ++j)
+ {
+ output[outIndex] = params[j];
+ ++outIndex;
+ }
+ }
+
+ BOOST_ASSERT(outIndex == outputInfo.GetNumElements());
+}
+
+template void Gather<float>(const TensorInfo& paramsInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& outputInfo,
+ const float* params,
+ const int32_t* indices,
+ float* output);
+
+template void Gather<uint8_t>(const TensorInfo& paramsInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& outputInfo,
+ const uint8_t* params,
+ const int32_t* indices,
+ uint8_t* output);
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/Gather.hpp b/src/backends/reference/workloads/Gather.hpp
new file mode 100644
index 0000000000..0ad4f8ceb6
--- /dev/null
+++ b/src/backends/reference/workloads/Gather.hpp
@@ -0,0 +1,21 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "armnn/Tensor.hpp"
+
+namespace armnn
+{
+
+template <typename T>
+void Gather(const TensorInfo& paramsInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& outputInfo,
+ const T* params,
+ const int32_t* indices,
+ T* output);
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefGatherWorkload.cpp b/src/backends/reference/workloads/RefGatherWorkload.cpp
new file mode 100644
index 0000000000..49b37cb1ac
--- /dev/null
+++ b/src/backends/reference/workloads/RefGatherWorkload.cpp
@@ -0,0 +1,37 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefGatherWorkload.hpp"
+
+#include "Gather.hpp"
+#include "Profiling.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "TypeUtils.hpp"
+
+namespace armnn
+{
+
+template <armnn::DataType DataType>
+void RefGatherWorkload<DataType>::Execute() const
+{
+ using T = ResolveType<DataType>;
+
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefGatherWorkload_Execute");
+
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ const T* paramsData = GetInputTensorData<T>(0, m_Data);
+ const int32_t* indicesData = GetInputTensorData<int32_t>(1, m_Data);
+ T* outputData = GetOutputTensorData<T>(0, m_Data);
+
+ Gather(inputInfo0, inputInfo1, outputInfo, paramsData, indicesData, outputData);
+}
+
+template class RefGatherWorkload<DataType::Float32>;
+template class RefGatherWorkload<DataType::QuantisedAsymm8>;
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefGatherWorkload.hpp b/src/backends/reference/workloads/RefGatherWorkload.hpp
new file mode 100644
index 0000000000..27827490e3
--- /dev/null
+++ b/src/backends/reference/workloads/RefGatherWorkload.hpp
@@ -0,0 +1,36 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+namespace armnn
+{
+
+template <armnn::DataType DataType>
+class RefGatherWorkload : public FirstInputTypedWorkload<GatherQueueDescriptor, DataType>
+{
+public:
+
+ static const std::string& GetName()
+ {
+ static const std::string name = std::string("RefGather") + GetDataTypeName(DataType) + "Workload";
+ return name;
+ }
+
+ using FirstInputTypedWorkload<GatherQueueDescriptor, DataType>::m_Data;
+ using FirstInputTypedWorkload<GatherQueueDescriptor, DataType>::FirstInputTypedWorkload;
+
+ void Execute() const override;
+};
+
+using RefGatherFloat32Workload = RefGatherWorkload<DataType::Float32>;
+using RefGatherUint8Workload = RefGatherWorkload<DataType::QuantisedAsymm8>;
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 8beb03fe32..8550ee583e 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -19,6 +19,7 @@
#include "RefWorkloadUtils.hpp"
#include "RefMergerUint8Workload.hpp"
#include "RefFullyConnectedFloat32Workload.hpp"
+#include "RefGatherWorkload.hpp"
#include "Softmax.hpp"
#include "RefMergerFloat32Workload.hpp"
#include "TensorBufferArrayView.hpp"
@@ -28,6 +29,7 @@
#include "RefReshapeFloat32Workload.hpp"
#include "RefDepthwiseConvolution2dUint8Workload.hpp"
#include "FullyConnected.hpp"
+#include "Gather.hpp"
#include "RefFloorFloat32Workload.hpp"
#include "RefSoftmaxFloat32Workload.hpp"
#include "RefSoftmaxUint8Workload.hpp"