aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp13
-rw-r--r--src/backends/reference/RefLayerSupport.hpp5
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp12
-rw-r--r--src/backends/reference/backend.mk2
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp8
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt4
-rw-r--r--src/backends/reference/workloads/Gather.cpp64
-rw-r--r--src/backends/reference/workloads/Gather.hpp21
-rw-r--r--src/backends/reference/workloads/RefGatherWorkload.cpp37
-rw-r--r--src/backends/reference/workloads/RefGatherWorkload.hpp36
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp2
11 files changed, 198 insertions, 6 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 61a34f957e..ce81f8d38a 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -257,6 +257,19 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
&TrueFunc<>);
}
+bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
+ const armnn::TensorInfo& input1,
+ const armnn::TensorInfo& output,
+ armnn::Optional<std::string&> reasonIfUnsupported) const
+{
+ ignore_unused(input1);
+ ignore_unused(output);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input0.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
+}
+
bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 5778806f00..01abc73dd5 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -91,6 +91,11 @@ public:
const FullyConnectedDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsGatherSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index cb7d6ea01a..9bdda9d128 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -318,16 +318,16 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateRsqrt(const RsqrtQueueDescr
return MakeWorkload<RefRsqrtFloat32Workload, NullWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGather(const armnn::GatherQueueDescriptor& descriptor,
+ const armnn::WorkloadInfo& info) const
{
- return nullptr;
+ return MakeWorkload<RefGatherFloat32Workload, RefGatherUint8Workload>(descriptor, info);
}
-std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGather(const armnn::GatherQueueDescriptor& descriptor,
- const armnn::WorkloadInfo& info) const
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ return nullptr;
}
} // namespace armnn
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 84f15c9c80..8dd6a51139 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -18,6 +18,7 @@ BACKEND_SOURCES := \
workloads/Debug.cpp \
workloads/ElementwiseFunction.cpp \
workloads/FullyConnected.cpp \
+ workloads/Gather.cpp \
workloads/Mean.cpp \
workloads/Pad.cpp \
workloads/Pooling2d.cpp \
@@ -42,6 +43,7 @@ BACKEND_SOURCES := \
workloads/RefFloorFloat32Workload.cpp \
workloads/RefFullyConnectedFloat32Workload.cpp \
workloads/RefFullyConnectedUint8Workload.cpp \
+ workloads/RefGatherWorkload.cpp \
workloads/RefL2NormalizationFloat32Workload.cpp \
workloads/RefLstmFloat32Workload.cpp \
workloads/RefMeanFloat32Workload.cpp \
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 50c47aecf9..cfe02e673e 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -492,4 +492,12 @@ ARMNN_AUTO_TEST_CASE(Debug3DUint8, Debug3DUint8Test)
ARMNN_AUTO_TEST_CASE(Debug2DUint8, Debug2DUint8Test)
ARMNN_AUTO_TEST_CASE(Debug1DUint8, Debug1DUint8Test)
+// Gather
+ARMNN_AUTO_TEST_CASE(Gather1DParamsFloat, Gather1DParamsFloatTest)
+ARMNN_AUTO_TEST_CASE(Gather1DParamsUint8, Gather1DParamsUint8Test)
+ARMNN_AUTO_TEST_CASE(GatherMultiDimParamsFloat, GatherMultiDimParamsFloatTest)
+ARMNN_AUTO_TEST_CASE(GatherMultiDimParamsUint8, GatherMultiDimParamsUint8Test)
+ARMNN_AUTO_TEST_CASE(GatherMultiDimParamsMultiDimIndicesFloat, GatherMultiDimParamsMultiDimIndicesFloatTest)
+ARMNN_AUTO_TEST_CASE(GatherMultiDimParamsMultiDimIndicesUint8, GatherMultiDimParamsMultiDimIndicesUint8Test)
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index d15f77d6e4..583c89a5b4 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -19,6 +19,8 @@ list(APPEND armnnRefBackendWorkloads_sources
ElementwiseFunction.hpp
FullyConnected.cpp
FullyConnected.hpp
+ Gather.cpp
+ Gather.hpp
Maximum.hpp
Merger.hpp
Minimum.hpp
@@ -68,6 +70,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefFullyConnectedFloat32Workload.hpp
RefFullyConnectedUint8Workload.cpp
RefFullyConnectedUint8Workload.hpp
+ RefGatherWorkload.cpp
+ RefGatherWorkload.hpp
RefL2NormalizationFloat32Workload.cpp
RefL2NormalizationFloat32Workload.hpp
RefLstmFloat32Workload.cpp
diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp
new file mode 100644
index 0000000000..b195003e04
--- /dev/null
+++ b/src/backends/reference/workloads/Gather.cpp
@@ -0,0 +1,64 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "Gather.hpp"
+
+#include "RefWorkloadUtils.hpp"
+
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+template <typename T>
+void Gather(const TensorInfo& paramsInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& outputInfo,
+ const T* params,
+ const int32_t* indices,
+ T* output)
+{
+ const TensorShape& paramsShape = paramsInfo.GetShape();
+
+ unsigned int paramsProduct = 1;
+ for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i)
+ {
+ paramsProduct = paramsProduct * paramsShape[i];
+ }
+
+ unsigned int outIndex = 0;
+ for (unsigned int i = 0; i < indicesInfo.GetNumElements(); ++i)
+ {
+ unsigned int indx = boost::numeric_cast<unsigned int>(indices[i]);
+
+ BOOST_ASSERT(indices[i] >= 0 && indx < paramsShape[0]);
+
+ unsigned int startOffset = indx * paramsProduct;
+ unsigned int endOffset = startOffset + paramsProduct;
+ for (unsigned int j = startOffset; j < endOffset; ++j)
+ {
+ output[outIndex] = params[j];
+ ++outIndex;
+ }
+ }
+
+ BOOST_ASSERT(outIndex == outputInfo.GetNumElements());
+}
+
+template void Gather<float>(const TensorInfo& paramsInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& outputInfo,
+ const float* params,
+ const int32_t* indices,
+ float* output);
+
+template void Gather<uint8_t>(const TensorInfo& paramsInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& outputInfo,
+ const uint8_t* params,
+ const int32_t* indices,
+ uint8_t* output);
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/Gather.hpp b/src/backends/reference/workloads/Gather.hpp
new file mode 100644
index 0000000000..0ad4f8ceb6
--- /dev/null
+++ b/src/backends/reference/workloads/Gather.hpp
@@ -0,0 +1,21 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "armnn/Tensor.hpp"
+
+namespace armnn
+{
+
+template <typename T>
+void Gather(const TensorInfo& paramsInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& outputInfo,
+ const T* params,
+ const int32_t* indices,
+ T* output);
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefGatherWorkload.cpp b/src/backends/reference/workloads/RefGatherWorkload.cpp
new file mode 100644
index 0000000000..49b37cb1ac
--- /dev/null
+++ b/src/backends/reference/workloads/RefGatherWorkload.cpp
@@ -0,0 +1,37 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefGatherWorkload.hpp"
+
+#include "Gather.hpp"
+#include "Profiling.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "TypeUtils.hpp"
+
+namespace armnn
+{
+
+template <armnn::DataType DataType>
+void RefGatherWorkload<DataType>::Execute() const
+{
+ using T = ResolveType<DataType>;
+
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefGatherWorkload_Execute");
+
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ const T* paramsData = GetInputTensorData<T>(0, m_Data);
+ const int32_t* indicesData = GetInputTensorData<int32_t>(1, m_Data);
+ T* outputData = GetOutputTensorData<T>(0, m_Data);
+
+ Gather(inputInfo0, inputInfo1, outputInfo, paramsData, indicesData, outputData);
+}
+
+template class RefGatherWorkload<DataType::Float32>;
+template class RefGatherWorkload<DataType::QuantisedAsymm8>;
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefGatherWorkload.hpp b/src/backends/reference/workloads/RefGatherWorkload.hpp
new file mode 100644
index 0000000000..27827490e3
--- /dev/null
+++ b/src/backends/reference/workloads/RefGatherWorkload.hpp
@@ -0,0 +1,36 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+namespace armnn
+{
+
+template <armnn::DataType DataType>
+class RefGatherWorkload : public FirstInputTypedWorkload<GatherQueueDescriptor, DataType>
+{
+public:
+
+ static const std::string& GetName()
+ {
+ static const std::string name = std::string("RefGather") + GetDataTypeName(DataType) + "Workload";
+ return name;
+ }
+
+ using FirstInputTypedWorkload<GatherQueueDescriptor, DataType>::m_Data;
+ using FirstInputTypedWorkload<GatherQueueDescriptor, DataType>::FirstInputTypedWorkload;
+
+ void Execute() const override;
+};
+
+using RefGatherFloat32Workload = RefGatherWorkload<DataType::Float32>;
+using RefGatherUint8Workload = RefGatherWorkload<DataType::QuantisedAsymm8>;
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 8beb03fe32..8550ee583e 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -19,6 +19,7 @@
#include "RefWorkloadUtils.hpp"
#include "RefMergerUint8Workload.hpp"
#include "RefFullyConnectedFloat32Workload.hpp"
+#include "RefGatherWorkload.hpp"
#include "Softmax.hpp"
#include "RefMergerFloat32Workload.hpp"
#include "TensorBufferArrayView.hpp"
@@ -28,6 +29,7 @@
#include "RefReshapeFloat32Workload.hpp"
#include "RefDepthwiseConvolution2dUint8Workload.hpp"
#include "FullyConnected.hpp"
+#include "Gather.hpp"
#include "RefFloorFloat32Workload.hpp"
#include "RefSoftmaxFloat32Workload.hpp"
#include "RefSoftmaxUint8Workload.hpp"