aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/StridedSlice.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/StridedSlice.cpp')
-rw-r--r--src/backends/reference/workloads/StridedSlice.cpp67
1 files changed, 29 insertions, 38 deletions
diff --git a/src/backends/reference/workloads/StridedSlice.cpp b/src/backends/reference/workloads/StridedSlice.cpp
index 71903e421d..9f2b1e76f6 100644
--- a/src/backends/reference/workloads/StridedSlice.cpp
+++ b/src/backends/reference/workloads/StridedSlice.cpp
@@ -5,12 +5,19 @@
#include "StridedSlice.hpp"
+#include <ResolveType.hpp>
+
#include <boost/assert.hpp>
#include <boost/numeric/conversion/cast.hpp>
+#include <cstring>
+
namespace armnn
{
+namespace
+{
+
void PadParams(StridedSliceDescriptor& p, unsigned int dimCount)
{
BOOST_ASSERT_MSG(dimCount <= 4, "Expected input with at most 4 dimensions");
@@ -78,42 +85,37 @@ TensorShape ExtendShape(const TensorShape& inputShape,
return TensorShape(newNumDimensions, newSizes);
}
-template<typename T>
+} // Anonymous namespace
+
void StridedSlice(const TensorInfo& inputInfo,
- const TensorInfo& outputInfo,
const StridedSliceDescriptor& params,
- const T* inputData,
- T* outputData)
+ const void* inputData,
+ void* outputData,
+ unsigned int dataTypeSize)
{
- const TensorShape inputShape =
- ExtendShape(inputInfo.GetShape(), 4);
+ const unsigned char* input = reinterpret_cast<const unsigned char*>(inputData);
+ unsigned char* output = reinterpret_cast<unsigned char*>(outputData);
+
+ const TensorShape inputShape = ExtendShape(inputInfo.GetShape(), 4);
StridedSliceDescriptor paddedParams = params;
// Pad parameters to 4 dimensions
PadParams(paddedParams, 4);
- const int start0 =
- paddedParams.GetStartForAxis(inputShape, 0);
- const int stop0 =
- paddedParams.GetStopForAxis(inputShape, 0, start0);
+ const int start0 = paddedParams.GetStartForAxis(inputShape, 0);
+ const int stop0 = paddedParams.GetStopForAxis (inputShape, 0, start0);
- const int start1 =
- paddedParams.GetStartForAxis(inputShape, 1);
- const int stop1 =
- paddedParams.GetStopForAxis(inputShape, 1, start1);
+ const int start1 = paddedParams.GetStartForAxis(inputShape, 1);
+ const int stop1 = paddedParams.GetStopForAxis (inputShape, 1, start1);
- const int start2 =
- paddedParams.GetStartForAxis(inputShape, 2);
- const int stop2 =
- paddedParams.GetStopForAxis(inputShape, 2, start2);
+ const int start2 = paddedParams.GetStartForAxis(inputShape, 2);
+ const int stop2 = paddedParams.GetStopForAxis (inputShape, 2, start2);
- const int start3 =
- paddedParams.GetStartForAxis(inputShape, 3);
- const int stop3 =
- paddedParams.GetStopForAxis(inputShape, 3, start3);
+ const int start3 = paddedParams.GetStartForAxis(inputShape, 3);
+ const int stop3 = paddedParams.GetStopForAxis (inputShape, 3, start3);
- T* outPtr = outputData;
+ const int step = boost::numeric_cast<int>(dataTypeSize);
for (int in0 = start0;
!LoopCondition(in0, stop0, paddedParams.m_Stride[0]);
@@ -135,24 +137,13 @@ void StridedSlice(const TensorInfo& inputInfo,
int dim2 = boost::numeric_cast<int>(inputShape[2]);
int dim3 = boost::numeric_cast<int>(inputShape[3]);
- int inputOffset = ((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3;
- *(outPtr++) = inputData[inputOffset];
+ int inputOffset = (((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3) * step;
+ ::memcpy(output, input + inputOffset, dataTypeSize);
+ output += step;
}
}
}
}
}
-template void StridedSlice<float>(const TensorInfo& inputInfo,
- const TensorInfo& outputInfo,
- const StridedSliceDescriptor& params,
- const float* inputData,
- float* outData);
-
-template void StridedSlice<uint8_t>(const TensorInfo& inputInfo,
- const TensorInfo& outputInfo,
- const StridedSliceDescriptor& params,
- const uint8_t* inputData,
- uint8_t* outData);
-
-} //namespace armnn
+} // namespace armnn