aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/Gather.cpp
diff options
context:
space:
mode:
authornarpra01 <narumol.prangnawarat@arm.com>2019-01-18 16:53:53 +0000
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-01-22 17:46:51 +0000
commit4951d84b1174a4bb0a5d9c900740f64201f765bf (patch)
treefe713740ac0acbaa8d74bbc9cdb450d08ef9f575 /src/backends/reference/workloads/Gather.cpp
parent0edd46737065d3e5c09aa959807e81f9836ee709 (diff)
downloadarmnn-4951d84b1174a4bb0a5d9c900740f64201f765bf.tar.gz
IVGCVSW-2510 Ref workload implementation for Gather operator
* add implemenentation for GatherQueueDescriptor validate function * add FirstInputTypedWorkload to allow type check on the first input tensor only * add ref workload implemenentation for float and uint8 * add Gather layer support in Ref * unit tests Change-Id: I4578a3211f11d24aa29d15bcf7f45b0445bcd1ee
Diffstat (limited to 'src/backends/reference/workloads/Gather.cpp')
-rw-r--r--src/backends/reference/workloads/Gather.cpp64
1 files changed, 64 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp
new file mode 100644
index 0000000000..b195003e04
--- /dev/null
+++ b/src/backends/reference/workloads/Gather.cpp
@@ -0,0 +1,64 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "Gather.hpp"
+
+#include "RefWorkloadUtils.hpp"
+
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+template <typename T>
+void Gather(const TensorInfo& paramsInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& outputInfo,
+ const T* params,
+ const int32_t* indices,
+ T* output)
+{
+ const TensorShape& paramsShape = paramsInfo.GetShape();
+
+ unsigned int paramsProduct = 1;
+ for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i)
+ {
+ paramsProduct = paramsProduct * paramsShape[i];
+ }
+
+ unsigned int outIndex = 0;
+ for (unsigned int i = 0; i < indicesInfo.GetNumElements(); ++i)
+ {
+ unsigned int indx = boost::numeric_cast<unsigned int>(indices[i]);
+
+ BOOST_ASSERT(indices[i] >= 0 && indx < paramsShape[0]);
+
+ unsigned int startOffset = indx * paramsProduct;
+ unsigned int endOffset = startOffset + paramsProduct;
+ for (unsigned int j = startOffset; j < endOffset; ++j)
+ {
+ output[outIndex] = params[j];
+ ++outIndex;
+ }
+ }
+
+ BOOST_ASSERT(outIndex == outputInfo.GetNumElements());
+}
+
+template void Gather<float>(const TensorInfo& paramsInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& outputInfo,
+ const float* params,
+ const int32_t* indices,
+ float* output);
+
+template void Gather<uint8_t>(const TensorInfo& paramsInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& outputInfo,
+ const uint8_t* params,
+ const int32_t* indices,
+ uint8_t* output);
+
+} //namespace armnn