aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2018-11-23 15:33:41 +0000
committernattapat.chaimanowong <nattapat.chaimanowong@arm.com>2018-11-26 16:43:23 +0000
commit1216b585e085bc3aa0941b7dea6e263e978cb22c (patch)
treebd04eb4cc92383f788acb79489ecce81ef1bfac5 /src/backends/reference
parent144c01b56d5e8b2f9d8e84a03f7fe975888ee25a (diff)
downloadarmnn-1216b585e085bc3aa0941b7dea6e263e978cb22c.tar.gz
IVGCVSW-2087 Reference implementation and unit tests for StridedSlice
Change-Id: Ifeacc0adb4547c72537b9ea7a61bf3c4ec3673fa
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp13
-rw-r--r--src/backends/reference/RefLayerSupport.hpp5
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp2
-rw-r--r--src/backends/reference/backend.mk2
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp21
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt4
-rw-r--r--src/backends/reference/workloads/RefStridedSliceWorkload.cpp34
-rw-r--r--src/backends/reference/workloads/RefStridedSliceWorkload.hpp34
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp1
-rw-r--r--src/backends/reference/workloads/StridedSlice.cpp158
-rw-r--r--src/backends/reference/workloads/StridedSlice.hpp21
11 files changed, 294 insertions, 1 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 43a2fa2d07..00e4c5c09c 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -462,6 +462,19 @@ bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
&TrueFunc<>);
}
+bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const StridedSliceDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ ignore_unused(output);
+ ignore_unused(descriptor);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
+}
+
bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index a03c89c48c..defa962847 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -174,6 +174,11 @@ public:
const ViewsDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsStridedSliceSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const StridedSliceDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsSubtractionSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 6d51b3d039..da8669ce6b 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -279,7 +279,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchT
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<RefStridedSliceFloat32Workload, RefStridedSliceUint8Workload>(descriptor, info);
}
} // namespace armnn
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 7d56144f18..7162d4a81e 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -58,10 +58,12 @@ BACKEND_SOURCES := \
workloads/RefSoftmaxFloat32Workload.cpp \
workloads/RefSoftmaxUint8Workload.cpp \
workloads/RefSpaceToBatchNdWorkload.cpp \
+ workloads/RefStridedSliceWorkload.cpp \
workloads/RefSplitterFloat32Workload.cpp \
workloads/RefSplitterUint8Workload.cpp \
workloads/ResizeBilinear.cpp \
workloads/SpaceToBatchNd.cpp \
+ workloads/StridedSlice.cpp \
workloads/Softmax.cpp
# BACKEND_TEST_SOURCES contains the list of files to be included
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index aba9f3eb0e..b16849929e 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -412,4 +412,25 @@ ARMNN_AUTO_TEST_CASE(BatchToSpaceNdNchwFloat321, BatchToSpaceNdNchwFloat32Test1)
ARMNN_AUTO_TEST_CASE(BatchToSpaceNdNhwcUint1, BatchToSpaceNdNhwcUintTest1)
+// Strided Slice
+ARMNN_AUTO_TEST_CASE(StridedSlice4DFloat32, StridedSlice4DFloat32Test)
+ARMNN_AUTO_TEST_CASE(StridedSlice4DReverseFloat32, StridedSlice4DReverseFloat32Test)
+ARMNN_AUTO_TEST_CASE(StridedSliceSimpleStrideFloat32, StridedSliceSimpleStrideFloat32Test)
+ARMNN_AUTO_TEST_CASE(StridedSliceSimpleRangeMaskFloat32, StridedSliceSimpleRangeMaskFloat32Test)
+ARMNN_AUTO_TEST_CASE(StridedSliceShrinkAxisMaskFloat32, StridedSliceShrinkAxisMaskFloat32Test)
+ARMNN_AUTO_TEST_CASE(StridedSlice3DFloat32, StridedSlice3DFloat32Test)
+ARMNN_AUTO_TEST_CASE(StridedSlice3DReverseFloat32, StridedSlice3DReverseFloat32Test)
+ARMNN_AUTO_TEST_CASE(StridedSlice2DFloat32, StridedSlice2DFloat32Test)
+ARMNN_AUTO_TEST_CASE(StridedSlice2DReverseFloat32, StridedSlice2DReverseFloat32Test)
+
+ARMNN_AUTO_TEST_CASE(StridedSlice4DUint8, StridedSlice4DUint8Test)
+ARMNN_AUTO_TEST_CASE(StridedSlice4DReverseUint8, StridedSlice4DReverseUint8Test)
+ARMNN_AUTO_TEST_CASE(StridedSliceSimpleStrideUint8, StridedSliceSimpleStrideUint8Test)
+ARMNN_AUTO_TEST_CASE(StridedSliceSimpleRangeMaskUint8, StridedSliceSimpleRangeMaskUint8Test)
+ARMNN_AUTO_TEST_CASE(StridedSliceShrinkAxisMaskUint8, StridedSliceShrinkAxisMaskUint8Test)
+ARMNN_AUTO_TEST_CASE(StridedSlice3DUint8, StridedSlice3DUint8Test)
+ARMNN_AUTO_TEST_CASE(StridedSlice3DReverseUint8, StridedSlice3DReverseUint8Test)
+ARMNN_AUTO_TEST_CASE(StridedSlice2DUint8, StridedSlice2DUint8Test)
+ARMNN_AUTO_TEST_CASE(StridedSlice2DReverseUint8, StridedSlice2DReverseUint8Test)
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 1c38509ca0..2d9ad926f7 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -98,6 +98,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefSplitterFloat32Workload.hpp
RefSplitterUint8Workload.cpp
RefSplitterUint8Workload.hpp
+ RefStridedSliceWorkload.cpp
+ RefStridedSliceWorkload.hpp
RefWorkloads.hpp
RefWorkloadUtils.hpp
ResizeBilinear.cpp
@@ -107,6 +109,8 @@ list(APPEND armnnRefBackendWorkloads_sources
SpaceToBatchNd.hpp
SpaceToBatchNd.cpp
Splitter.hpp
+ StridedSlice.hpp
+ StridedSlice.cpp
TensorBufferArrayView.hpp
Mean.cpp
Mean.hpp
diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp
new file mode 100644
index 0000000000..26a878e02f
--- /dev/null
+++ b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp
@@ -0,0 +1,34 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefStridedSliceWorkload.hpp"
+#include "StridedSlice.hpp"
+
+#include "RefWorkloadUtils.hpp"
+#include "TypeUtils.hpp"
+
+namespace armnn
+{
+
+template<armnn::DataType DataType>
+void RefStridedSliceWorkload<DataType>::Execute() const
+{
+ using T = ResolveType<DataType>;
+
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute");
+
+ const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ const T* inputData = GetInputTensorData<T>(0, m_Data);
+ T* outputData = GetOutputTensorData<T>(0, m_Data);
+
+ StridedSlice(inputInfo, outputInfo, m_Data.m_Parameters, inputData, outputData);
+}
+
+template class RefStridedSliceWorkload<DataType::Float32>;
+template class RefStridedSliceWorkload<DataType::QuantisedAsymm8>;
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.hpp b/src/backends/reference/workloads/RefStridedSliceWorkload.hpp
new file mode 100644
index 0000000000..b3586adbda
--- /dev/null
+++ b/src/backends/reference/workloads/RefStridedSliceWorkload.hpp
@@ -0,0 +1,34 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+namespace armnn
+{
+
+template <armnn::DataType DataType>
+class RefStridedSliceWorkload : public TypedWorkload<StridedSliceQueueDescriptor, DataType>
+{
+public:
+ static const std::string& GetName()
+ {
+ static const std::string name = std::string("RefStridedSlice") + GetDataTypeName(DataType) + "Workload";
+ return name;
+ }
+
+ using TypedWorkload<StridedSliceQueueDescriptor, DataType>::m_Data;
+ using TypedWorkload<StridedSliceQueueDescriptor, DataType>::TypedWorkload;
+
+ void Execute() const override;
+};
+
+using RefStridedSliceFloat32Workload = RefStridedSliceWorkload<DataType::Float32>;
+using RefStridedSliceUint8Workload = RefStridedSliceWorkload<DataType::QuantisedAsymm8>;
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 5ea7fe4b58..20e9a9f5d3 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -43,6 +43,7 @@
#include "Merger.hpp"
#include "RefSpaceToBatchNdWorkload.hpp"
#include "RefSplitterFloat32Workload.hpp"
+#include "RefStridedSliceWorkload.hpp"
#include "RefConstantFloat32Workload.hpp"
#include "RefActivationFloat32Workload.hpp"
#include "RefConvolution2dFloat32Workload.hpp"
diff --git a/src/backends/reference/workloads/StridedSlice.cpp b/src/backends/reference/workloads/StridedSlice.cpp
new file mode 100644
index 0000000000..71903e421d
--- /dev/null
+++ b/src/backends/reference/workloads/StridedSlice.cpp
@@ -0,0 +1,158 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "StridedSlice.hpp"
+
+#include <boost/assert.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+
+namespace armnn
+{
+
+void PadParams(StridedSliceDescriptor& p, unsigned int dimCount)
+{
+ BOOST_ASSERT_MSG(dimCount <= 4, "Expected input with at most 4 dimensions");
+
+ const unsigned int beginIndicesCount =
+ boost::numeric_cast<unsigned int>(p.m_Begin.size());
+
+ BOOST_ASSERT(dimCount >= beginIndicesCount);
+ const unsigned int padCount = dimCount - beginIndicesCount;
+
+ p.m_Begin.resize(dimCount);
+ p.m_End.resize(dimCount);
+ p.m_Stride.resize(dimCount);
+
+ for (unsigned int i = beginIndicesCount; i > 0; --i)
+ {
+ p.m_Stride[i + padCount - 1] = p.m_Stride[i - 1];
+ p.m_Begin[i + padCount - 1] = p.m_Begin[i - 1];
+ p.m_End[i + padCount - 1] = p.m_End[i - 1];
+ }
+
+ for (unsigned int i = 0; i < padCount; ++i)
+ {
+ p.m_Stride[i] = 1;
+ p.m_Begin[i] = 0;
+ p.m_End[i] = 0;
+ }
+
+ p.m_ShrinkAxisMask <<= padCount;
+ p.m_EllipsisMask <<= padCount;
+ p.m_NewAxisMask <<= padCount;
+ p.m_BeginMask <<= padCount;
+ p.m_EndMask <<= padCount;
+ p.m_BeginMask |= (1 << padCount) - 1;
+ p.m_EndMask |= (1 << padCount) - 1;
+}
+
+bool LoopCondition(int index, int stop, int stride)
+{
+ return stride > 0 ? index >= stop : index <= stop;
+}
+
+TensorShape ExtendShape(const TensorShape& inputShape,
+ unsigned int newNumDimensions)
+{
+ if (inputShape.GetNumDimensions() >= newNumDimensions)
+ {
+ return inputShape;
+ }
+
+ unsigned int newSizes[newNumDimensions];
+
+ unsigned int diff = newNumDimensions - inputShape.GetNumDimensions();
+
+ for (unsigned int i = 0; i < diff; i++)
+ {
+ newSizes[i] = 1;
+ }
+
+ for (unsigned int i = diff; i < newNumDimensions; i++)
+ {
+ newSizes[i] = inputShape[i - diff];
+ }
+
+ return TensorShape(newNumDimensions, newSizes);
+}
+
+template<typename T>
+void StridedSlice(const TensorInfo& inputInfo,
+ const TensorInfo& outputInfo,
+ const StridedSliceDescriptor& params,
+ const T* inputData,
+ T* outputData)
+{
+ const TensorShape inputShape =
+ ExtendShape(inputInfo.GetShape(), 4);
+
+ StridedSliceDescriptor paddedParams = params;
+
+ // Pad parameters to 4 dimensions
+ PadParams(paddedParams, 4);
+
+ const int start0 =
+ paddedParams.GetStartForAxis(inputShape, 0);
+ const int stop0 =
+ paddedParams.GetStopForAxis(inputShape, 0, start0);
+
+ const int start1 =
+ paddedParams.GetStartForAxis(inputShape, 1);
+ const int stop1 =
+ paddedParams.GetStopForAxis(inputShape, 1, start1);
+
+ const int start2 =
+ paddedParams.GetStartForAxis(inputShape, 2);
+ const int stop2 =
+ paddedParams.GetStopForAxis(inputShape, 2, start2);
+
+ const int start3 =
+ paddedParams.GetStartForAxis(inputShape, 3);
+ const int stop3 =
+ paddedParams.GetStopForAxis(inputShape, 3, start3);
+
+ T* outPtr = outputData;
+
+ for (int in0 = start0;
+ !LoopCondition(in0, stop0, paddedParams.m_Stride[0]);
+ in0 += paddedParams.m_Stride[0])
+ {
+ for (int in1 = start1;
+ !LoopCondition(in1, stop1, paddedParams.m_Stride[1]);
+ in1 += paddedParams.m_Stride[1])
+ {
+ for (int in2 = start2;
+ !LoopCondition(in2, stop2, paddedParams.m_Stride[2]);
+ in2 += paddedParams.m_Stride[2])
+ {
+ for (int in3 = start3;
+ !LoopCondition(in3, stop3, paddedParams.m_Stride[3]);
+ in3 += paddedParams.m_Stride[3])
+ {
+ int dim1 = boost::numeric_cast<int>(inputShape[1]);
+ int dim2 = boost::numeric_cast<int>(inputShape[2]);
+ int dim3 = boost::numeric_cast<int>(inputShape[3]);
+
+ int inputOffset = ((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3;
+ *(outPtr++) = inputData[inputOffset];
+ }
+ }
+ }
+ }
+}
+
+template void StridedSlice<float>(const TensorInfo& inputInfo,
+ const TensorInfo& outputInfo,
+ const StridedSliceDescriptor& params,
+ const float* inputData,
+ float* outData);
+
+template void StridedSlice<uint8_t>(const TensorInfo& inputInfo,
+ const TensorInfo& outputInfo,
+ const StridedSliceDescriptor& params,
+ const uint8_t* inputData,
+ uint8_t* outData);
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/StridedSlice.hpp b/src/backends/reference/workloads/StridedSlice.hpp
new file mode 100644
index 0000000000..8eed8706dc
--- /dev/null
+++ b/src/backends/reference/workloads/StridedSlice.hpp
@@ -0,0 +1,21 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Descriptors.hpp>
+#include <armnn/Tensor.hpp>
+
+namespace armnn
+{
+
+template <typename T>
+void StridedSlice(const TensorInfo& inputInfo,
+ const TensorInfo& outputInfo,
+ const StridedSliceDescriptor& params,
+ const T* inputData,
+ T* outputData);
+
+} //namespace armnn