diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index c94fa25ac2..c4f1b24d1e 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -125,6 +125,42 @@ void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType, } //--------------------------------------------------------------- +void ValidateTensorQuantizationSpace(const TensorInfo& first, + const TensorInfo& second, + const std::string& descName, + std::string const& firstName, + std::string const& secondName) +{ + if (!first.IsQuantized() || + !second.IsQuantized()) + { + // Not a quantized type, ignore the validation + return; + } + + DataType firstDataType = first.GetDataType(); + DataType secondDataType = second.GetDataType(); + + if (firstDataType != secondDataType) + { + throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName + + " must be of the same quantized type, " + + firstName + " is " + GetDataTypeName(firstDataType) + ", " + + secondName + " is " + GetDataTypeName(secondDataType)); + } + + if (!first.IsTypeSpaceMatch(second)) + { + throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName + + " must have the same quantization space, " + + firstName + " has offset " + to_string(first.GetQuantizationOffset()) + + " and scale " + to_string(first.GetQuantizationScale()) + ", " + + secondName + " has offset " + to_string(second.GetQuantizationOffset()) + + " and scale " + to_string(second.GetQuantizationScale())); + } +} + +//--------------------------------------------------------------- void ValidateBiasTensorQuantization(const TensorInfo& biasTensor, const TensorInfo& inputTensorInfo, const TensorInfo& weightsTensorInfo, const std::string& descName) { @@ -1214,6 +1250,22 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con ValidateNumOutputs(workloadInfo, "StridedSliceQueueDescriptor", 1); const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; + + std::vector<DataType> supportedTypes = + { + DataType::Float16, + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(input, supportedTypes, "StridedSliceQueueDescriptor"); + ValidateDataTypes(output, supportedTypes, "StridedSliceQueueDescriptor"); + + ValidateDataTypes(output, { input.GetDataType() }, "StridedSliceQueueDescriptor"); + + ValidateTensorQuantizationSpace(input, output, "StridedSliceQueueDescriptor", "input", "output"); + const uint32_t rank = input.GetNumDimensions(); if (rank > 4) |