aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-05-28 14:31:20 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-05-31 11:57:50 +0100
commite851b3da2ba51edc69c7b3dbfad06c4e22a63595 (patch)
tree5ead856b8c4de5198170f8ff3fdb2541eb6676d9
parent01961a7df1c4357981a33b9c1eb80fb51888a8fa (diff)
downloadarmnn-e851b3da2ba51edc69c7b3dbfad06c4e22a63595.tar.gz
IVGCVSW-3170 Refactor the Strided Slice Ref workload for Float32 and
QAsymm8 types * RefStridedSliceWorkload is no longer a template class * Refactoring of the ref StridedSlice implementation * Added ValidateTensorQuantizationSpace function Change-Id: Ifa182a33d79d42137731f48b995a7973c9d92152 Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp52
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp2
-rw-r--r--src/backends/reference/workloads/RefStridedSliceWorkload.cpp36
-rw-r--r--src/backends/reference/workloads/RefStridedSliceWorkload.hpp20
-rw-r--r--src/backends/reference/workloads/StridedSlice.cpp67
-rw-r--r--src/backends/reference/workloads/StridedSlice.hpp9
6 files changed, 110 insertions, 76 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index c94fa25ac2..c4f1b24d1e 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -125,6 +125,42 @@ void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
}
//---------------------------------------------------------------
+void ValidateTensorQuantizationSpace(const TensorInfo& first,
+ const TensorInfo& second,
+ const std::string& descName,
+ std::string const& firstName,
+ std::string const& secondName)
+{
+ if (!first.IsQuantized() ||
+ !second.IsQuantized())
+ {
+ // Not a quantized type, ignore the validation
+ return;
+ }
+
+ DataType firstDataType = first.GetDataType();
+ DataType secondDataType = second.GetDataType();
+
+ if (firstDataType != secondDataType)
+ {
+ throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
+ " must be of the same quantized type, " +
+ firstName + " is " + GetDataTypeName(firstDataType) + ", " +
+ secondName + " is " + GetDataTypeName(secondDataType));
+ }
+
+ if (!first.IsTypeSpaceMatch(second))
+ {
+ throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
+ " must have the same quantization space, " +
+ firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
+ " and scale " + to_string(first.GetQuantizationScale()) + ", " +
+ secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
+ " and scale " + to_string(second.GetQuantizationScale()));
+ }
+}
+
+//---------------------------------------------------------------
void ValidateBiasTensorQuantization(const TensorInfo& biasTensor, const TensorInfo& inputTensorInfo,
const TensorInfo& weightsTensorInfo, const std::string& descName)
{
@@ -1214,6 +1250,22 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
ValidateNumOutputs(workloadInfo, "StridedSliceQueueDescriptor", 1);
const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
+
+ std::vector<DataType> supportedTypes =
+ {
+ DataType::Float16,
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(input, supportedTypes, "StridedSliceQueueDescriptor");
+ ValidateDataTypes(output, supportedTypes, "StridedSliceQueueDescriptor");
+
+ ValidateDataTypes(output, { input.GetDataType() }, "StridedSliceQueueDescriptor");
+
+ ValidateTensorQuantizationSpace(input, output, "StridedSliceQueueDescriptor", "input", "output");
+
const uint32_t rank = input.GetNumDimensions();
if (rank > 4)
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index bf2339756c..b829d54be0 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -370,7 +370,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchT
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<RefStridedSliceFloat32Workload, RefStridedSliceUint8Workload>(descriptor, info);
+ return std::make_unique<RefStridedSliceWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp
index bcc3520f45..8bb1670a48 100644
--- a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp
+++ b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp
@@ -4,31 +4,37 @@
//
#include "RefStridedSliceWorkload.hpp"
+#include "RefWorkloadUtils.hpp"
#include "StridedSlice.hpp"
-#include "RefWorkloadUtils.hpp"
-#include <ResolveType.hpp>
+#include <boost/format.hpp>
namespace armnn
{
-template<armnn::DataType DataType>
-void RefStridedSliceWorkload<DataType>::Execute() const
-{
- using T = ResolveType<DataType>;
+RefStridedSliceWorkload::RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor,
+ const WorkloadInfo& info)
+ : BaseWorkload(descriptor, info)
+{}
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute");
+void RefStridedSliceWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefStridedSliceWorkload_Execute");
- const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
- const T* inputData = GetInputTensorData<T>(0, m_Data);
- T* outputData = GetOutputTensorData<T>(0, m_Data);
+ DataType inputDataType = inputInfo.GetDataType();
+ DataType outputDataType = outputInfo.GetDataType();
- StridedSlice(inputInfo, outputInfo, m_Data.m_Parameters, inputData, outputData);
-}
+ BOOST_ASSERT(inputDataType == outputDataType);
+ boost::ignore_unused(outputDataType);
-template class RefStridedSliceWorkload<DataType::Float32>;
-template class RefStridedSliceWorkload<DataType::QuantisedAsymm8>;
+ StridedSlice(inputInfo,
+ m_Data.m_Parameters,
+ m_Data.m_Inputs[0]->Map(),
+ m_Data.m_Outputs[0]->Map(),
+ GetDataTypeSize(inputDataType));
+}
-} //namespace armnn
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.hpp b/src/backends/reference/workloads/RefStridedSliceWorkload.hpp
index b3586adbda..44aabc0106 100644
--- a/src/backends/reference/workloads/RefStridedSliceWorkload.hpp
+++ b/src/backends/reference/workloads/RefStridedSliceWorkload.hpp
@@ -7,28 +7,14 @@
#include <backendsCommon/Workload.hpp>
-#include <armnn/TypesUtils.hpp>
-
namespace armnn
{
-template <armnn::DataType DataType>
-class RefStridedSliceWorkload : public TypedWorkload<StridedSliceQueueDescriptor, DataType>
+class RefStridedSliceWorkload : public BaseWorkload<StridedSliceQueueDescriptor>
{
public:
- static const std::string& GetName()
- {
- static const std::string name = std::string("RefStridedSlice") + GetDataTypeName(DataType) + "Workload";
- return name;
- }
-
- using TypedWorkload<StridedSliceQueueDescriptor, DataType>::m_Data;
- using TypedWorkload<StridedSliceQueueDescriptor, DataType>::TypedWorkload;
-
+ RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor, const WorkloadInfo& info);
void Execute() const override;
};
-using RefStridedSliceFloat32Workload = RefStridedSliceWorkload<DataType::Float32>;
-using RefStridedSliceUint8Workload = RefStridedSliceWorkload<DataType::QuantisedAsymm8>;
-
-} //namespace armnn
+} // namespace armnn
diff --git a/src/backends/reference/workloads/StridedSlice.cpp b/src/backends/reference/workloads/StridedSlice.cpp
index 71903e421d..9f2b1e76f6 100644
--- a/src/backends/reference/workloads/StridedSlice.cpp
+++ b/src/backends/reference/workloads/StridedSlice.cpp
@@ -5,12 +5,19 @@
#include "StridedSlice.hpp"
+#include <ResolveType.hpp>
+
#include <boost/assert.hpp>
#include <boost/numeric/conversion/cast.hpp>
+#include <cstring>
+
namespace armnn
{
+namespace
+{
+
void PadParams(StridedSliceDescriptor& p, unsigned int dimCount)
{
BOOST_ASSERT_MSG(dimCount <= 4, "Expected input with at most 4 dimensions");
@@ -78,42 +85,37 @@ TensorShape ExtendShape(const TensorShape& inputShape,
return TensorShape(newNumDimensions, newSizes);
}
-template<typename T>
+} // Anonymous namespace
+
void StridedSlice(const TensorInfo& inputInfo,
- const TensorInfo& outputInfo,
const StridedSliceDescriptor& params,
- const T* inputData,
- T* outputData)
+ const void* inputData,
+ void* outputData,
+ unsigned int dataTypeSize)
{
- const TensorShape inputShape =
- ExtendShape(inputInfo.GetShape(), 4);
+ const unsigned char* input = reinterpret_cast<const unsigned char*>(inputData);
+ unsigned char* output = reinterpret_cast<unsigned char*>(outputData);
+
+ const TensorShape inputShape = ExtendShape(inputInfo.GetShape(), 4);
StridedSliceDescriptor paddedParams = params;
// Pad parameters to 4 dimensions
PadParams(paddedParams, 4);
- const int start0 =
- paddedParams.GetStartForAxis(inputShape, 0);
- const int stop0 =
- paddedParams.GetStopForAxis(inputShape, 0, start0);
+ const int start0 = paddedParams.GetStartForAxis(inputShape, 0);
+ const int stop0 = paddedParams.GetStopForAxis (inputShape, 0, start0);
- const int start1 =
- paddedParams.GetStartForAxis(inputShape, 1);
- const int stop1 =
- paddedParams.GetStopForAxis(inputShape, 1, start1);
+ const int start1 = paddedParams.GetStartForAxis(inputShape, 1);
+ const int stop1 = paddedParams.GetStopForAxis (inputShape, 1, start1);
- const int start2 =
- paddedParams.GetStartForAxis(inputShape, 2);
- const int stop2 =
- paddedParams.GetStopForAxis(inputShape, 2, start2);
+ const int start2 = paddedParams.GetStartForAxis(inputShape, 2);
+ const int stop2 = paddedParams.GetStopForAxis (inputShape, 2, start2);
- const int start3 =
- paddedParams.GetStartForAxis(inputShape, 3);
- const int stop3 =
- paddedParams.GetStopForAxis(inputShape, 3, start3);
+ const int start3 = paddedParams.GetStartForAxis(inputShape, 3);
+ const int stop3 = paddedParams.GetStopForAxis (inputShape, 3, start3);
- T* outPtr = outputData;
+ const int step = boost::numeric_cast<int>(dataTypeSize);
for (int in0 = start0;
!LoopCondition(in0, stop0, paddedParams.m_Stride[0]);
@@ -135,24 +137,13 @@ void StridedSlice(const TensorInfo& inputInfo,
int dim2 = boost::numeric_cast<int>(inputShape[2]);
int dim3 = boost::numeric_cast<int>(inputShape[3]);
- int inputOffset = ((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3;
- *(outPtr++) = inputData[inputOffset];
+ int inputOffset = (((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3) * step;
+ ::memcpy(output, input + inputOffset, dataTypeSize);
+ output += step;
}
}
}
}
}
-template void StridedSlice<float>(const TensorInfo& inputInfo,
- const TensorInfo& outputInfo,
- const StridedSliceDescriptor& params,
- const float* inputData,
- float* outData);
-
-template void StridedSlice<uint8_t>(const TensorInfo& inputInfo,
- const TensorInfo& outputInfo,
- const StridedSliceDescriptor& params,
- const uint8_t* inputData,
- uint8_t* outData);
-
-} //namespace armnn
+} // namespace armnn
diff --git a/src/backends/reference/workloads/StridedSlice.hpp b/src/backends/reference/workloads/StridedSlice.hpp
index 8eed8706dc..b13a8e4e33 100644
--- a/src/backends/reference/workloads/StridedSlice.hpp
+++ b/src/backends/reference/workloads/StridedSlice.hpp
@@ -11,11 +11,10 @@
namespace armnn
{
-template <typename T>
void StridedSlice(const TensorInfo& inputInfo,
- const TensorInfo& outputInfo,
const StridedSliceDescriptor& params,
- const T* inputData,
- T* outputData);
+ const void* inputData,
+ void* outputData,
+ unsigned int dataTypeSize);
-} //namespace armnn
+} // namespace armnn