diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-05-28 14:31:20 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-05-31 11:57:50 +0100 |
commit | e851b3da2ba51edc69c7b3dbfad06c4e22a63595 (patch) | |
tree | 5ead856b8c4de5198170f8ff3fdb2541eb6676d9 /src/backends/reference/workloads/StridedSlice.cpp | |
parent | 01961a7df1c4357981a33b9c1eb80fb51888a8fa (diff) | |
download | armnn-e851b3da2ba51edc69c7b3dbfad06c4e22a63595.tar.gz |
IVGCVSW-3170 Refactor the Strided Slice Ref workload for Float32 and
QAsymm8 types
* RefStridedSliceWorkload is no longer a template class
* Refactoring of the ref StridedSlice implementation
* Added ValidateTensorQuantizationSpace function
Change-Id: Ifa182a33d79d42137731f48b995a7973c9d92152
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Diffstat (limited to 'src/backends/reference/workloads/StridedSlice.cpp')
-rw-r--r-- | src/backends/reference/workloads/StridedSlice.cpp | 67 |
1 files changed, 29 insertions, 38 deletions
diff --git a/src/backends/reference/workloads/StridedSlice.cpp b/src/backends/reference/workloads/StridedSlice.cpp index 71903e421d..9f2b1e76f6 100644 --- a/src/backends/reference/workloads/StridedSlice.cpp +++ b/src/backends/reference/workloads/StridedSlice.cpp @@ -5,12 +5,19 @@ #include "StridedSlice.hpp" +#include <ResolveType.hpp> + #include <boost/assert.hpp> #include <boost/numeric/conversion/cast.hpp> +#include <cstring> + namespace armnn { +namespace +{ + void PadParams(StridedSliceDescriptor& p, unsigned int dimCount) { BOOST_ASSERT_MSG(dimCount <= 4, "Expected input with at most 4 dimensions"); @@ -78,42 +85,37 @@ TensorShape ExtendShape(const TensorShape& inputShape, return TensorShape(newNumDimensions, newSizes); } -template<typename T> +} // Anonymous namespace + void StridedSlice(const TensorInfo& inputInfo, - const TensorInfo& outputInfo, const StridedSliceDescriptor& params, - const T* inputData, - T* outputData) + const void* inputData, + void* outputData, + unsigned int dataTypeSize) { - const TensorShape inputShape = - ExtendShape(inputInfo.GetShape(), 4); + const unsigned char* input = reinterpret_cast<const unsigned char*>(inputData); + unsigned char* output = reinterpret_cast<unsigned char*>(outputData); + + const TensorShape inputShape = ExtendShape(inputInfo.GetShape(), 4); StridedSliceDescriptor paddedParams = params; // Pad parameters to 4 dimensions PadParams(paddedParams, 4); - const int start0 = - paddedParams.GetStartForAxis(inputShape, 0); - const int stop0 = - paddedParams.GetStopForAxis(inputShape, 0, start0); + const int start0 = paddedParams.GetStartForAxis(inputShape, 0); + const int stop0 = paddedParams.GetStopForAxis (inputShape, 0, start0); - const int start1 = - paddedParams.GetStartForAxis(inputShape, 1); - const int stop1 = - paddedParams.GetStopForAxis(inputShape, 1, start1); + const int start1 = paddedParams.GetStartForAxis(inputShape, 1); + const int stop1 = paddedParams.GetStopForAxis (inputShape, 1, start1); - const int start2 = - paddedParams.GetStartForAxis(inputShape, 2); - const int stop2 = - paddedParams.GetStopForAxis(inputShape, 2, start2); + const int start2 = paddedParams.GetStartForAxis(inputShape, 2); + const int stop2 = paddedParams.GetStopForAxis (inputShape, 2, start2); - const int start3 = - paddedParams.GetStartForAxis(inputShape, 3); - const int stop3 = - paddedParams.GetStopForAxis(inputShape, 3, start3); + const int start3 = paddedParams.GetStartForAxis(inputShape, 3); + const int stop3 = paddedParams.GetStopForAxis (inputShape, 3, start3); - T* outPtr = outputData; + const int step = boost::numeric_cast<int>(dataTypeSize); for (int in0 = start0; !LoopCondition(in0, stop0, paddedParams.m_Stride[0]); @@ -135,24 +137,13 @@ void StridedSlice(const TensorInfo& inputInfo, int dim2 = boost::numeric_cast<int>(inputShape[2]); int dim3 = boost::numeric_cast<int>(inputShape[3]); - int inputOffset = ((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3; - *(outPtr++) = inputData[inputOffset]; + int inputOffset = (((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3) * step; + ::memcpy(output, input + inputOffset, dataTypeSize); + output += step; } } } } } -template void StridedSlice<float>(const TensorInfo& inputInfo, - const TensorInfo& outputInfo, - const StridedSliceDescriptor& params, - const float* inputData, - float* outData); - -template void StridedSlice<uint8_t>(const TensorInfo& inputInfo, - const TensorInfo& outputInfo, - const StridedSliceDescriptor& params, - const uint8_t* inputData, - uint8_t* outData); - -} //namespace armnn +} // namespace armnn |