diff options
Diffstat (limited to 'src/backends/reference/workloads/StridedSlice.cpp')
-rw-r--r-- | src/backends/reference/workloads/StridedSlice.cpp | 67 |
1 files changed, 29 insertions, 38 deletions
diff --git a/src/backends/reference/workloads/StridedSlice.cpp b/src/backends/reference/workloads/StridedSlice.cpp index 71903e421d..9f2b1e76f6 100644 --- a/src/backends/reference/workloads/StridedSlice.cpp +++ b/src/backends/reference/workloads/StridedSlice.cpp @@ -5,12 +5,19 @@ #include "StridedSlice.hpp" +#include <ResolveType.hpp> + #include <boost/assert.hpp> #include <boost/numeric/conversion/cast.hpp> +#include <cstring> + namespace armnn { +namespace +{ + void PadParams(StridedSliceDescriptor& p, unsigned int dimCount) { BOOST_ASSERT_MSG(dimCount <= 4, "Expected input with at most 4 dimensions"); @@ -78,42 +85,37 @@ TensorShape ExtendShape(const TensorShape& inputShape, return TensorShape(newNumDimensions, newSizes); } -template<typename T> +} // Anonymous namespace + void StridedSlice(const TensorInfo& inputInfo, - const TensorInfo& outputInfo, const StridedSliceDescriptor& params, - const T* inputData, - T* outputData) + const void* inputData, + void* outputData, + unsigned int dataTypeSize) { - const TensorShape inputShape = - ExtendShape(inputInfo.GetShape(), 4); + const unsigned char* input = reinterpret_cast<const unsigned char*>(inputData); + unsigned char* output = reinterpret_cast<unsigned char*>(outputData); + + const TensorShape inputShape = ExtendShape(inputInfo.GetShape(), 4); StridedSliceDescriptor paddedParams = params; // Pad parameters to 4 dimensions PadParams(paddedParams, 4); - const int start0 = - paddedParams.GetStartForAxis(inputShape, 0); - const int stop0 = - paddedParams.GetStopForAxis(inputShape, 0, start0); + const int start0 = paddedParams.GetStartForAxis(inputShape, 0); + const int stop0 = paddedParams.GetStopForAxis (inputShape, 0, start0); - const int start1 = - paddedParams.GetStartForAxis(inputShape, 1); - const int stop1 = - paddedParams.GetStopForAxis(inputShape, 1, start1); + const int start1 = paddedParams.GetStartForAxis(inputShape, 1); + const int stop1 = paddedParams.GetStopForAxis (inputShape, 1, start1); - const int start2 = - paddedParams.GetStartForAxis(inputShape, 2); - const int stop2 = - paddedParams.GetStopForAxis(inputShape, 2, start2); + const int start2 = paddedParams.GetStartForAxis(inputShape, 2); + const int stop2 = paddedParams.GetStopForAxis (inputShape, 2, start2); - const int start3 = - paddedParams.GetStartForAxis(inputShape, 3); - const int stop3 = - paddedParams.GetStopForAxis(inputShape, 3, start3); + const int start3 = paddedParams.GetStartForAxis(inputShape, 3); + const int stop3 = paddedParams.GetStopForAxis (inputShape, 3, start3); - T* outPtr = outputData; + const int step = boost::numeric_cast<int>(dataTypeSize); for (int in0 = start0; !LoopCondition(in0, stop0, paddedParams.m_Stride[0]); @@ -135,24 +137,13 @@ void StridedSlice(const TensorInfo& inputInfo, int dim2 = boost::numeric_cast<int>(inputShape[2]); int dim3 = boost::numeric_cast<int>(inputShape[3]); - int inputOffset = ((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3; - *(outPtr++) = inputData[inputOffset]; + int inputOffset = (((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3) * step; + ::memcpy(output, input + inputOffset, dataTypeSize); + output += step; } } } } } -template void StridedSlice<float>(const TensorInfo& inputInfo, - const TensorInfo& outputInfo, - const StridedSliceDescriptor& params, - const float* inputData, - float* outData); - -template void StridedSlice<uint8_t>(const TensorInfo& inputInfo, - const TensorInfo& outputInfo, - const StridedSliceDescriptor& params, - const uint8_t* inputData, - uint8_t* outData); - -} //namespace armnn +} // namespace armnn |