diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-05-28 14:31:20 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-05-31 11:57:50 +0100 |
commit | e851b3da2ba51edc69c7b3dbfad06c4e22a63595 (patch) | |
tree | 5ead856b8c4de5198170f8ff3fdb2541eb6676d9 /src/backends/backendsCommon/WorkloadData.cpp | |
parent | 01961a7df1c4357981a33b9c1eb80fb51888a8fa (diff) | |
download | armnn-e851b3da2ba51edc69c7b3dbfad06c4e22a63595.tar.gz |
IVGCVSW-3170 Refactor the Strided Slice Ref workload for Float32 and
QAsymm8 types
* RefStridedSliceWorkload is no longer a template class
* Refactoring of the ref StridedSlice implementation
* Added ValidateTensorQuantizationSpace function
Change-Id: Ifa182a33d79d42137731f48b995a7973c9d92152
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index c94fa25ac2..c4f1b24d1e 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -125,6 +125,42 @@ void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType, } //--------------------------------------------------------------- +void ValidateTensorQuantizationSpace(const TensorInfo& first, + const TensorInfo& second, + const std::string& descName, + std::string const& firstName, + std::string const& secondName) +{ + if (!first.IsQuantized() || + !second.IsQuantized()) + { + // Not a quantized type, ignore the validation + return; + } + + DataType firstDataType = first.GetDataType(); + DataType secondDataType = second.GetDataType(); + + if (firstDataType != secondDataType) + { + throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName + + " must be of the same quantized type, " + + firstName + " is " + GetDataTypeName(firstDataType) + ", " + + secondName + " is " + GetDataTypeName(secondDataType)); + } + + if (!first.IsTypeSpaceMatch(second)) + { + throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName + + " must have the same quantization space, " + + firstName + " has offset " + to_string(first.GetQuantizationOffset()) + + " and scale " + to_string(first.GetQuantizationScale()) + ", " + + secondName + " has offset " + to_string(second.GetQuantizationOffset()) + + " and scale " + to_string(second.GetQuantizationScale())); + } +} + +//--------------------------------------------------------------- void ValidateBiasTensorQuantization(const TensorInfo& biasTensor, const TensorInfo& inputTensorInfo, const TensorInfo& weightsTensorInfo, const std::string& descName) { @@ -1214,6 +1250,22 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con ValidateNumOutputs(workloadInfo, "StridedSliceQueueDescriptor", 1); const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; + + std::vector<DataType> supportedTypes = + { + DataType::Float16, + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(input, supportedTypes, "StridedSliceQueueDescriptor"); + ValidateDataTypes(output, supportedTypes, "StridedSliceQueueDescriptor"); + + ValidateDataTypes(output, { input.GetDataType() }, "StridedSliceQueueDescriptor"); + + ValidateTensorQuantizationSpace(input, output, "StridedSliceQueueDescriptor", "input", "output"); + const uint32_t rank = input.GetNumDimensions(); if (rank > 4) |