From e851b3da2ba51edc69c7b3dbfad06c4e22a63595 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Tue, 28 May 2019 14:31:20 +0100 Subject: IVGCVSW-3170 Refactor the Strided Slice Ref workload for Float32 and QAsymm8 types * RefStridedSliceWorkload is no longer a template class * Refactoring of the ref StridedSlice implementation * Added ValidateTensorQuantizationSpace function Change-Id: Ifa182a33d79d42137731f48b995a7973c9d92152 Signed-off-by: Matteo Martincigh --- src/backends/reference/workloads/StridedSlice.cpp | 67 ++++++++++------------- 1 file changed, 29 insertions(+), 38 deletions(-) (limited to 'src/backends/reference/workloads/StridedSlice.cpp') diff --git a/src/backends/reference/workloads/StridedSlice.cpp b/src/backends/reference/workloads/StridedSlice.cpp index 71903e421d..9f2b1e76f6 100644 --- a/src/backends/reference/workloads/StridedSlice.cpp +++ b/src/backends/reference/workloads/StridedSlice.cpp @@ -5,12 +5,19 @@ #include "StridedSlice.hpp" +#include + #include #include +#include + namespace armnn { +namespace +{ + void PadParams(StridedSliceDescriptor& p, unsigned int dimCount) { BOOST_ASSERT_MSG(dimCount <= 4, "Expected input with at most 4 dimensions"); @@ -78,42 +85,37 @@ TensorShape ExtendShape(const TensorShape& inputShape, return TensorShape(newNumDimensions, newSizes); } -template +} // Anonymous namespace + void StridedSlice(const TensorInfo& inputInfo, - const TensorInfo& outputInfo, const StridedSliceDescriptor& params, - const T* inputData, - T* outputData) + const void* inputData, + void* outputData, + unsigned int dataTypeSize) { - const TensorShape inputShape = - ExtendShape(inputInfo.GetShape(), 4); + const unsigned char* input = reinterpret_cast(inputData); + unsigned char* output = reinterpret_cast(outputData); + + const TensorShape inputShape = ExtendShape(inputInfo.GetShape(), 4); StridedSliceDescriptor paddedParams = params; // Pad parameters to 4 dimensions PadParams(paddedParams, 4); - const int start0 = - paddedParams.GetStartForAxis(inputShape, 0); - const int stop0 = - paddedParams.GetStopForAxis(inputShape, 0, start0); + const int start0 = paddedParams.GetStartForAxis(inputShape, 0); + const int stop0 = paddedParams.GetStopForAxis (inputShape, 0, start0); - const int start1 = - paddedParams.GetStartForAxis(inputShape, 1); - const int stop1 = - paddedParams.GetStopForAxis(inputShape, 1, start1); + const int start1 = paddedParams.GetStartForAxis(inputShape, 1); + const int stop1 = paddedParams.GetStopForAxis (inputShape, 1, start1); - const int start2 = - paddedParams.GetStartForAxis(inputShape, 2); - const int stop2 = - paddedParams.GetStopForAxis(inputShape, 2, start2); + const int start2 = paddedParams.GetStartForAxis(inputShape, 2); + const int stop2 = paddedParams.GetStopForAxis (inputShape, 2, start2); - const int start3 = - paddedParams.GetStartForAxis(inputShape, 3); - const int stop3 = - paddedParams.GetStopForAxis(inputShape, 3, start3); + const int start3 = paddedParams.GetStartForAxis(inputShape, 3); + const int stop3 = paddedParams.GetStopForAxis (inputShape, 3, start3); - T* outPtr = outputData; + const int step = boost::numeric_cast(dataTypeSize); for (int in0 = start0; !LoopCondition(in0, stop0, paddedParams.m_Stride[0]); @@ -135,24 +137,13 @@ void StridedSlice(const TensorInfo& inputInfo, int dim2 = boost::numeric_cast(inputShape[2]); int dim3 = boost::numeric_cast(inputShape[3]); - int inputOffset = ((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3; - *(outPtr++) = inputData[inputOffset]; + int inputOffset = (((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3) * step; + ::memcpy(output, input + inputOffset, dataTypeSize); + output += step; } } } } } -template void StridedSlice(const TensorInfo& inputInfo, - const TensorInfo& outputInfo, - const StridedSliceDescriptor& params, - const float* inputData, - float* outData); - -template void StridedSlice(const TensorInfo& inputInfo, - const TensorInfo& outputInfo, - const StridedSliceDescriptor& params, - const uint8_t* inputData, - uint8_t* outData); - -} //namespace armnn +} // namespace armnn -- cgit v1.2.1