From e851b3da2ba51edc69c7b3dbfad06c4e22a63595 Mon Sep 17 00:00:00 2001
From: Matteo Martincigh
Date: Tue, 28 May 2019 14:31:20 +0100
Subject: IVGCVSW-3170 Refactor the Strided Slice Ref workload for Float32 and
 QAsymm8 types

* RefStridedSliceWorkload is no longer a template class
* Refactoring of the ref StridedSlice implementation
* Added ValidateTensorQuantizationSpace function

Change-Id: Ifa182a33d79d42137731f48b995a7973c9d92152
Signed-off-by: Matteo Martincigh
---
 src/backends/backendsCommon/WorkloadData.cpp      | 52 +++++++++++++++++
 src/backends/reference/RefWorkloadFactory.cpp     |  2 +-
 .../workloads/RefStridedSliceWorkload.cpp         | 36 +++++++-----
 .../workloads/RefStridedSliceWorkload.hpp         | 20 +------
 src/backends/reference/workloads/StridedSlice.cpp | 67 ++++++++++------------
 src/backends/reference/workloads/StridedSlice.hpp |  9 ++-
 6 files changed, 110 insertions(+), 76 deletions(-)

diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index c94fa25ac2..c4f1b24d1e 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -124,6 +124,42 @@ void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
     }
 }
 
+//---------------------------------------------------------------
+void ValidateTensorQuantizationSpace(const TensorInfo& first,
+                                     const TensorInfo& second,
+                                     const std::string& descName,
+                                     std::string const& firstName,
+                                     std::string const& secondName)
+{
+    if (!first.IsQuantized() ||
+        !second.IsQuantized())
+    {
+        // Not a quantized type, ignore the validation
+        return;
+    }
+
+    DataType firstDataType  = first.GetDataType();
+    DataType secondDataType = second.GetDataType();
+
+    if (firstDataType != secondDataType)
+    {
+        throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
+                                       " must be of the same quantized type, " +
+                                       firstName + " is " + GetDataTypeName(firstDataType) + ", " +
+                                       secondName + " is " + GetDataTypeName(secondDataType));
+    }
+
+    if (!first.IsTypeSpaceMatch(second))
+    {
+        throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
+                                       " must have the same quantization space, " +
+                                       firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
+                                       " and scale " + to_string(first.GetQuantizationScale()) + ", " +
+                                       secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
+                                       " and scale " + to_string(second.GetQuantizationScale()));
+    }
+}
+
 //---------------------------------------------------------------
 void ValidateBiasTensorQuantization(const TensorInfo& biasTensor, const TensorInfo& inputTensorInfo,
     const TensorInfo& weightsTensorInfo, const std::string& descName)
@@ -1214,6 +1250,22 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
     ValidateNumOutputs(workloadInfo, "StridedSliceQueueDescriptor", 1);
 
     const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
+    const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
+
+    std::vector<DataType> supportedTypes =
+    {
+        DataType::Float16,
+        DataType::Float32,
+        DataType::QuantisedAsymm8
+    };
+
+    ValidateDataTypes(input, supportedTypes, "StridedSliceQueueDescriptor");
+    ValidateDataTypes(output, supportedTypes, "StridedSliceQueueDescriptor");
+
+    ValidateDataTypes(output, { input.GetDataType() }, "StridedSliceQueueDescriptor");
+
+    ValidateTensorQuantizationSpace(input, output, "StridedSliceQueueDescriptor", "input", "output");
+
     const uint32_t rank = input.GetNumDimensions();
 
     if (rank > 4)
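
[Illustration, not part of the patch] The new ValidateTensorQuantizationSpace helper is internal to WorkloadData.cpp, so the sketch below simply re-states the rule it enforces (same quantized data type, matching scale and offset) against two hypothetical QAsymm8 tensor infos; only public armnn headers are assumed.

    #include <armnn/Exceptions.hpp>
    #include <armnn/Tensor.hpp>
    #include <armnn/Types.hpp>

    // Re-statement of the check added above: quantized tensors must agree on
    // data type and on quantization scale/offset (TensorInfo::IsTypeSpaceMatch).
    void CheckQuantizationSpace(const armnn::TensorInfo& first, const armnn::TensorInfo& second)
    {
        if (first.IsQuantized() && second.IsQuantized() && !first.IsTypeSpaceMatch(second))
        {
            throw armnn::InvalidArgumentException("tensors must share quantization scale and offset");
        }
    }

    int main()
    {
        // Hypothetical QAsymm8 tensors: same type, different scale/offset.
        armnn::TensorInfo input ({ 1, 2, 2, 1 }, armnn::DataType::QuantisedAsymm8, 0.50f, 10);
        armnn::TensorInfo output({ 1, 2, 2, 1 }, armnn::DataType::QuantisedAsymm8, 0.25f,  0);

        CheckQuantizationSpace(input, output); // throws armnn::InvalidArgumentException
        return 0;
    }
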
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index bf2339756c..b829d54be0 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -370,7 +370,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchT
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
                                                                   const WorkloadInfo& info) const
 {
-    return MakeWorkload<RefStridedSliceFloat32Workload, RefStridedSliceUint8Workload>(descriptor, info);
+    return std::make_unique<RefStridedSliceWorkload>(descriptor, info);
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp
index bcc3520f45..8bb1670a48 100644
--- a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp
+++ b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp
@@ -4,31 +4,37 @@
 //
 
 #include "RefStridedSliceWorkload.hpp"
+#include "RefWorkloadUtils.hpp"
 
 #include "StridedSlice.hpp"
-#include "RefWorkloadUtils.hpp"
 
-#include <ResolveType.hpp>
+#include <boost/core/ignore_unused.hpp>
 
 namespace armnn
 {
 
-template<armnn::DataType DataType>
-void RefStridedSliceWorkload<DataType>::Execute() const
-{
-    using T = ResolveType<DataType>;
+RefStridedSliceWorkload::RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor,
+                                                 const WorkloadInfo& info)
+    : BaseWorkload<StridedSliceQueueDescriptor>(descriptor, info)
+{}
 
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute");
+void RefStridedSliceWorkload::Execute() const
+{
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefStridedSliceWorkload_Execute");
 
-    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& inputInfo  = GetTensorInfo(m_Data.m_Inputs[0]);
     const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
 
-    const T* inputData = GetInputTensorData<T>(0, m_Data);
-    T* outputData = GetOutputTensorData<T>(0, m_Data);
+    DataType inputDataType  = inputInfo.GetDataType();
+    DataType outputDataType = outputInfo.GetDataType();
 
-    StridedSlice(inputInfo, outputInfo, m_Data.m_Parameters, inputData, outputData);
-}
+    BOOST_ASSERT(inputDataType == outputDataType);
+    boost::ignore_unused(outputDataType);
 
-template class RefStridedSliceWorkload<DataType::Float32>;
-template class RefStridedSliceWorkload<DataType::QuantisedAsymm8>;
+    StridedSlice(inputInfo,
+                 m_Data.m_Parameters,
+                 m_Data.m_Inputs[0]->Map(),
+                 m_Data.m_Outputs[0]->Map(),
+                 GetDataTypeSize(inputDataType));
+}
 
-} //namespace armnn
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.hpp b/src/backends/reference/workloads/RefStridedSliceWorkload.hpp
index b3586adbda..44aabc0106 100644
--- a/src/backends/reference/workloads/RefStridedSliceWorkload.hpp
+++ b/src/backends/reference/workloads/RefStridedSliceWorkload.hpp
@@ -7,28 +7,14 @@
 
 #include <backendsCommon/Workload.hpp>
 
-#include <armnn/TypesUtils.hpp>
-
 namespace armnn
 {
 
-template<armnn::DataType DataType>
-class RefStridedSliceWorkload : public TypedWorkload<StridedSliceQueueDescriptor, DataType>
+class RefStridedSliceWorkload : public BaseWorkload<StridedSliceQueueDescriptor>
 {
 public:
-    static const std::string& GetName()
-    {
-        static const std::string name = std::string("RefStridedSlice") + GetDataTypeName(DataType) + "Workload";
-        return name;
-    }
-
-    using TypedWorkload<StridedSliceQueueDescriptor, DataType>::m_Data;
-    using TypedWorkload<StridedSliceQueueDescriptor, DataType>::TypedWorkload;
-
+    RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor, const WorkloadInfo& info);
     void Execute() const override;
 };
 
-using RefStridedSliceFloat32Workload = RefStridedSliceWorkload<DataType::Float32>;
-using RefStridedSliceUint8Workload = RefStridedSliceWorkload<DataType::QuantisedAsymm8>;
-
-} //namespace armnn
+} // namespace armnn
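
[Illustration, not part of the patch] The refactored workload passes only the element size down to the kernel. A minimal check of those sizes, assuming the public armnn/TypesUtils.hpp header that the workload itself relies on for GetDataTypeSize():

    #include <armnn/TypesUtils.hpp>

    #include <iostream>

    int main()
    {
        // 4 bytes per Float32 element, 1 byte per QAsymm8 element - the byte-wise
        // copy loop needs nothing else to support both types.
        std::cout << armnn::GetDataTypeSize(armnn::DataType::Float32)         << std::endl;
        std::cout << armnn::GetDataTypeSize(armnn::DataType::QuantisedAsymm8) << std::endl;
        return 0;
    }
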
diff --git a/src/backends/reference/workloads/StridedSlice.cpp b/src/backends/reference/workloads/StridedSlice.cpp
index 71903e421d..9f2b1e76f6 100644
--- a/src/backends/reference/workloads/StridedSlice.cpp
+++ b/src/backends/reference/workloads/StridedSlice.cpp
@@ -5,12 +5,19 @@
 
 #include "StridedSlice.hpp"
 
+#include <ResolveType.hpp>
+
 #include <boost/assert.hpp>
 #include <boost/numeric/conversion/cast.hpp>
 
+#include <cstring>
+
 namespace armnn
 {
 
+namespace
+{
+
 void PadParams(StridedSliceDescriptor& p, unsigned int dimCount)
 {
     BOOST_ASSERT_MSG(dimCount <= 4, "Expected input with at most 4 dimensions");
@@ -78,42 +85,37 @@ TensorShape ExtendShape(const TensorShape& inputShape,
     return TensorShape(newNumDimensions, newSizes);
 }
 
-template<typename T>
+} // Anonymous namespace
+
 void StridedSlice(const TensorInfo& inputInfo,
-                  const TensorInfo& outputInfo,
                   const StridedSliceDescriptor& params,
-                  const T* inputData,
-                  T* outputData)
+                  const void* inputData,
+                  void* outputData,
+                  unsigned int dataTypeSize)
 {
-    const TensorShape inputShape =
-        ExtendShape(inputInfo.GetShape(), 4);
+    const unsigned char* input = reinterpret_cast<const unsigned char*>(inputData);
+    unsigned char* output = reinterpret_cast<unsigned char*>(outputData);
+
+    const TensorShape inputShape = ExtendShape(inputInfo.GetShape(), 4);
 
     StridedSliceDescriptor paddedParams = params;
 
     // Pad parameters to 4 dimensions
    PadParams(paddedParams, 4);
 
-    const int start0 =
-        paddedParams.GetStartForAxis(inputShape, 0);
-    const int stop0 =
-        paddedParams.GetStopForAxis(inputShape, 0, start0);
+    const int start0 = paddedParams.GetStartForAxis(inputShape, 0);
+    const int stop0  = paddedParams.GetStopForAxis (inputShape, 0, start0);
 
-    const int start1 =
-        paddedParams.GetStartForAxis(inputShape, 1);
-    const int stop1 =
-        paddedParams.GetStopForAxis(inputShape, 1, start1);
+    const int start1 = paddedParams.GetStartForAxis(inputShape, 1);
+    const int stop1  = paddedParams.GetStopForAxis (inputShape, 1, start1);
 
-    const int start2 =
-        paddedParams.GetStartForAxis(inputShape, 2);
-    const int stop2 =
-        paddedParams.GetStopForAxis(inputShape, 2, start2);
+    const int start2 = paddedParams.GetStartForAxis(inputShape, 2);
+    const int stop2  = paddedParams.GetStopForAxis (inputShape, 2, start2);
 
-    const int start3 =
-        paddedParams.GetStartForAxis(inputShape, 3);
-    const int stop3 =
-        paddedParams.GetStopForAxis(inputShape, 3, start3);
+    const int start3 = paddedParams.GetStartForAxis(inputShape, 3);
+    const int stop3  = paddedParams.GetStopForAxis (inputShape, 3, start3);
 
-    T* outPtr = outputData;
+    const int step = boost::numeric_cast<int>(dataTypeSize);
 
     for (int in0 = start0;
          !LoopCondition(in0, stop0, paddedParams.m_Stride[0]);
@@ -135,24 +137,13 @@ void StridedSlice(const TensorInfo& inputInfo,
                     int dim2 = boost::numeric_cast<int>(inputShape[2]);
                     int dim3 = boost::numeric_cast<int>(inputShape[3]);
 
-                    int inputOffset = ((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3;
-                    *(outPtr++) = inputData[inputOffset];
+                    int inputOffset = (((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3) * step;
+                    ::memcpy(output, input + inputOffset, dataTypeSize);
+                    output += step;
                 }
             }
         }
     }
 }
 
-template void StridedSlice<float>(const TensorInfo& inputInfo,
-                                  const TensorInfo& outputInfo,
-                                  const StridedSliceDescriptor& params,
-                                  const float* inputData,
-                                  float* outData);
-
-template void StridedSlice<uint8_t>(const TensorInfo& inputInfo,
-                                    const TensorInfo& outputInfo,
-                                    const StridedSliceDescriptor& params,
-                                    const uint8_t* inputData,
-                                    uint8_t* outData);
-
-} //namespace armnn
+} // namespace armnn
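
[Illustration, not part of the patch] A self-contained toy version of the new copy step, using a 2-D array and invented values: the flat element index is scaled by the element size and one element is copied with memcpy, mirroring the loop body above.

    #include <cassert>
    #include <cstring>

    int main()
    {
        const float input[2][3] = { { 0.f, 1.f, 2.f }, { 3.f, 4.f, 5.f } };
        float output[1] = { 0.f };

        const unsigned int dataTypeSize = sizeof(float);
        const int step = static_cast<int>(dataTypeSize);

        // Select element [1][2]: flat index in0 * dim1 + in1, scaled to a byte offset.
        const int dim1 = 3;
        const int in0 = 1;
        const int in1 = 2;
        const int inputOffset = (in0 * dim1 + in1) * step;

        const unsigned char* src = reinterpret_cast<const unsigned char*>(&input[0][0]);
        unsigned char* dst = reinterpret_cast<unsigned char*>(&output[0]);
        std::memcpy(dst, src + inputOffset, dataTypeSize);

        assert(output[0] == 5.f);
        return 0;
    }
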
diff --git a/src/backends/reference/workloads/StridedSlice.hpp b/src/backends/reference/workloads/StridedSlice.hpp
index 8eed8706dc..b13a8e4e33 100644
--- a/src/backends/reference/workloads/StridedSlice.hpp
+++ b/src/backends/reference/workloads/StridedSlice.hpp
@@ -11,11 +11,10 @@
 namespace armnn
 {
 
-template<typename T>
 void StridedSlice(const TensorInfo& inputInfo,
-                  const TensorInfo& outputInfo,
                   const StridedSliceDescriptor& params,
-                  const T* inputData,
-                  T* outputData);
+                  const void* inputData,
+                  void* outputData,
+                  unsigned int dataTypeSize);
 
-} //namespace armnn
+} // namespace armnn
-- 
cgit v1.2.1
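
[Illustration, not part of the patch] A hypothetical driver for the refactored kernel. StridedSlice.hpp is an internal backend header, so the relative include below only resolves when built inside src/backends/reference/workloads; the 1-D shape, descriptor values and expected output are made up for the example.

    #include "StridedSlice.hpp"

    #include <armnn/Descriptors.hpp>
    #include <armnn/Tensor.hpp>

    #include <cassert>

    int main()
    {
        const float input[4] = { 0.f, 1.f, 2.f, 3.f };
        float output[2] = { 0.f, 0.f };

        armnn::TensorInfo inputInfo({ 4 }, armnn::DataType::Float32);

        armnn::StridedSliceDescriptor params;
        params.m_Begin  = { 0 };
        params.m_End    = { 4 };
        params.m_Stride = { 2 };    // take every other element

        // One call now serves every supported element type; only the element
        // size changes between Float32 and QAsymm8.
        armnn::StridedSlice(inputInfo, params, input, output, static_cast<unsigned int>(sizeof(float)));

        assert(output[0] == 0.f && output[1] == 2.f);
        return 0;
    }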