aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp52
1 files changed, 52 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index c94fa25ac2..c4f1b24d1e 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -125,6 +125,42 @@ void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
}
//---------------------------------------------------------------
+void ValidateTensorQuantizationSpace(const TensorInfo& first,
+ const TensorInfo& second,
+ const std::string& descName,
+ std::string const& firstName,
+ std::string const& secondName)
+{
+ if (!first.IsQuantized() ||
+ !second.IsQuantized())
+ {
+ // Not a quantized type, ignore the validation
+ return;
+ }
+
+ DataType firstDataType = first.GetDataType();
+ DataType secondDataType = second.GetDataType();
+
+ if (firstDataType != secondDataType)
+ {
+ throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
+ " must be of the same quantized type, " +
+ firstName + " is " + GetDataTypeName(firstDataType) + ", " +
+ secondName + " is " + GetDataTypeName(secondDataType));
+ }
+
+ if (!first.IsTypeSpaceMatch(second))
+ {
+ throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
+ " must have the same quantization space, " +
+ firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
+ " and scale " + to_string(first.GetQuantizationScale()) + ", " +
+ secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
+ " and scale " + to_string(second.GetQuantizationScale()));
+ }
+}
+
+//---------------------------------------------------------------
void ValidateBiasTensorQuantization(const TensorInfo& biasTensor, const TensorInfo& inputTensorInfo,
const TensorInfo& weightsTensorInfo, const std::string& descName)
{
@@ -1214,6 +1250,22 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
ValidateNumOutputs(workloadInfo, "StridedSliceQueueDescriptor", 1);
const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
+
+ std::vector<DataType> supportedTypes =
+ {
+ DataType::Float16,
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(input, supportedTypes, "StridedSliceQueueDescriptor");
+ ValidateDataTypes(output, supportedTypes, "StridedSliceQueueDescriptor");
+
+ ValidateDataTypes(output, { input.GetDataType() }, "StridedSliceQueueDescriptor");
+
+ ValidateTensorQuantizationSpace(input, output, "StridedSliceQueueDescriptor", "input", "output");
+
const uint32_t rank = input.GetNumDimensions();
if (rank > 4)