diff options
Diffstat (limited to 'src/backends/reference/workloads')
6 files changed, 252 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 1c38509ca0..2d9ad926f7 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -98,6 +98,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefSplitterFloat32Workload.hpp RefSplitterUint8Workload.cpp RefSplitterUint8Workload.hpp + RefStridedSliceWorkload.cpp + RefStridedSliceWorkload.hpp RefWorkloads.hpp RefWorkloadUtils.hpp ResizeBilinear.cpp @@ -107,6 +109,8 @@ list(APPEND armnnRefBackendWorkloads_sources SpaceToBatchNd.hpp SpaceToBatchNd.cpp Splitter.hpp + StridedSlice.hpp + StridedSlice.cpp TensorBufferArrayView.hpp Mean.cpp Mean.hpp diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp new file mode 100644 index 0000000000..26a878e02f --- /dev/null +++ b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp @@ -0,0 +1,34 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefStridedSliceWorkload.hpp" +#include "StridedSlice.hpp" + +#include "RefWorkloadUtils.hpp" +#include "TypeUtils.hpp" + +namespace armnn +{ + +template<armnn::DataType DataType> +void RefStridedSliceWorkload<DataType>::Execute() const +{ + using T = ResolveType<DataType>; + + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute"); + + const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + + const T* inputData = GetInputTensorData<T>(0, m_Data); + T* outputData = GetOutputTensorData<T>(0, m_Data); + + StridedSlice(inputInfo, outputInfo, m_Data.m_Parameters, inputData, outputData); +} + +template class RefStridedSliceWorkload<DataType::Float32>; +template class RefStridedSliceWorkload<DataType::QuantisedAsymm8>; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.hpp b/src/backends/reference/workloads/RefStridedSliceWorkload.hpp new file mode 100644 index 0000000000..b3586adbda --- /dev/null +++ b/src/backends/reference/workloads/RefStridedSliceWorkload.hpp @@ -0,0 +1,34 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <backendsCommon/Workload.hpp> + +#include <armnn/TypesUtils.hpp> + +namespace armnn +{ + +template <armnn::DataType DataType> +class RefStridedSliceWorkload : public TypedWorkload<StridedSliceQueueDescriptor, DataType> +{ +public: + static const std::string& GetName() + { + static const std::string name = std::string("RefStridedSlice") + GetDataTypeName(DataType) + "Workload"; + return name; + } + + using TypedWorkload<StridedSliceQueueDescriptor, DataType>::m_Data; + using TypedWorkload<StridedSliceQueueDescriptor, DataType>::TypedWorkload; + + void Execute() const override; +}; + +using RefStridedSliceFloat32Workload = RefStridedSliceWorkload<DataType::Float32>; +using RefStridedSliceUint8Workload = RefStridedSliceWorkload<DataType::QuantisedAsymm8>; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 5ea7fe4b58..20e9a9f5d3 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -43,6 +43,7 @@ #include "Merger.hpp" #include "RefSpaceToBatchNdWorkload.hpp" #include "RefSplitterFloat32Workload.hpp" +#include "RefStridedSliceWorkload.hpp" #include "RefConstantFloat32Workload.hpp" #include "RefActivationFloat32Workload.hpp" #include "RefConvolution2dFloat32Workload.hpp" diff --git a/src/backends/reference/workloads/StridedSlice.cpp b/src/backends/reference/workloads/StridedSlice.cpp new file mode 100644 index 0000000000..71903e421d --- /dev/null +++ b/src/backends/reference/workloads/StridedSlice.cpp @@ -0,0 +1,158 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "StridedSlice.hpp" + +#include <boost/assert.hpp> +#include <boost/numeric/conversion/cast.hpp> + +namespace armnn +{ + +void PadParams(StridedSliceDescriptor& p, unsigned int dimCount) +{ + BOOST_ASSERT_MSG(dimCount <= 4, "Expected input with at most 4 dimensions"); + + const unsigned int beginIndicesCount = + boost::numeric_cast<unsigned int>(p.m_Begin.size()); + + BOOST_ASSERT(dimCount >= beginIndicesCount); + const unsigned int padCount = dimCount - beginIndicesCount; + + p.m_Begin.resize(dimCount); + p.m_End.resize(dimCount); + p.m_Stride.resize(dimCount); + + for (unsigned int i = beginIndicesCount; i > 0; --i) + { + p.m_Stride[i + padCount - 1] = p.m_Stride[i - 1]; + p.m_Begin[i + padCount - 1] = p.m_Begin[i - 1]; + p.m_End[i + padCount - 1] = p.m_End[i - 1]; + } + + for (unsigned int i = 0; i < padCount; ++i) + { + p.m_Stride[i] = 1; + p.m_Begin[i] = 0; + p.m_End[i] = 0; + } + + p.m_ShrinkAxisMask <<= padCount; + p.m_EllipsisMask <<= padCount; + p.m_NewAxisMask <<= padCount; + p.m_BeginMask <<= padCount; + p.m_EndMask <<= padCount; + p.m_BeginMask |= (1 << padCount) - 1; + p.m_EndMask |= (1 << padCount) - 1; +} + +bool LoopCondition(int index, int stop, int stride) +{ + return stride > 0 ? index >= stop : index <= stop; +} + +TensorShape ExtendShape(const TensorShape& inputShape, + unsigned int newNumDimensions) +{ + if (inputShape.GetNumDimensions() >= newNumDimensions) + { + return inputShape; + } + + unsigned int newSizes[newNumDimensions]; + + unsigned int diff = newNumDimensions - inputShape.GetNumDimensions(); + + for (unsigned int i = 0; i < diff; i++) + { + newSizes[i] = 1; + } + + for (unsigned int i = diff; i < newNumDimensions; i++) + { + newSizes[i] = inputShape[i - diff]; + } + + return TensorShape(newNumDimensions, newSizes); +} + +template<typename T> +void StridedSlice(const TensorInfo& inputInfo, + const TensorInfo& outputInfo, + const StridedSliceDescriptor& params, + const T* inputData, + T* outputData) +{ + const TensorShape inputShape = + ExtendShape(inputInfo.GetShape(), 4); + + StridedSliceDescriptor paddedParams = params; + + // Pad parameters to 4 dimensions + PadParams(paddedParams, 4); + + const int start0 = + paddedParams.GetStartForAxis(inputShape, 0); + const int stop0 = + paddedParams.GetStopForAxis(inputShape, 0, start0); + + const int start1 = + paddedParams.GetStartForAxis(inputShape, 1); + const int stop1 = + paddedParams.GetStopForAxis(inputShape, 1, start1); + + const int start2 = + paddedParams.GetStartForAxis(inputShape, 2); + const int stop2 = + paddedParams.GetStopForAxis(inputShape, 2, start2); + + const int start3 = + paddedParams.GetStartForAxis(inputShape, 3); + const int stop3 = + paddedParams.GetStopForAxis(inputShape, 3, start3); + + T* outPtr = outputData; + + for (int in0 = start0; + !LoopCondition(in0, stop0, paddedParams.m_Stride[0]); + in0 += paddedParams.m_Stride[0]) + { + for (int in1 = start1; + !LoopCondition(in1, stop1, paddedParams.m_Stride[1]); + in1 += paddedParams.m_Stride[1]) + { + for (int in2 = start2; + !LoopCondition(in2, stop2, paddedParams.m_Stride[2]); + in2 += paddedParams.m_Stride[2]) + { + for (int in3 = start3; + !LoopCondition(in3, stop3, paddedParams.m_Stride[3]); + in3 += paddedParams.m_Stride[3]) + { + int dim1 = boost::numeric_cast<int>(inputShape[1]); + int dim2 = boost::numeric_cast<int>(inputShape[2]); + int dim3 = boost::numeric_cast<int>(inputShape[3]); + + int inputOffset = ((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3; + *(outPtr++) = inputData[inputOffset]; + } + } + } + } +} + +template void StridedSlice<float>(const TensorInfo& inputInfo, + const TensorInfo& outputInfo, + const StridedSliceDescriptor& params, + const float* inputData, + float* outData); + +template void StridedSlice<uint8_t>(const TensorInfo& inputInfo, + const TensorInfo& outputInfo, + const StridedSliceDescriptor& params, + const uint8_t* inputData, + uint8_t* outData); + +} //namespace armnn diff --git a/src/backends/reference/workloads/StridedSlice.hpp b/src/backends/reference/workloads/StridedSlice.hpp new file mode 100644 index 0000000000..8eed8706dc --- /dev/null +++ b/src/backends/reference/workloads/StridedSlice.hpp @@ -0,0 +1,21 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/Descriptors.hpp> +#include <armnn/Tensor.hpp> + +namespace armnn +{ + +template <typename T> +void StridedSlice(const TensorInfo& inputInfo, + const TensorInfo& outputInfo, + const StridedSliceDescriptor& params, + const T* inputData, + T* outputData); + +} //namespace armnn |