aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp21
1 files changed, 19 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index d2ab41ef40..075884b2da 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -149,6 +149,19 @@ void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
}
}
+void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
+{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
+ if (tensor.GetDataType() != DataType::QSymmS8 &&
+ tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
+ {
+ throw InvalidArgumentException(descName +
+ ": Expected data type which supports per-axis quantization scheme but got " +
+ GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
+ }
+ ARMNN_NO_DEPRECATE_WARN_END
+}
+
//---------------------------------------------------------------
void ValidateTensorQuantizationSpace(const TensorInfo& first,
const TensorInfo& second,
@@ -344,11 +357,14 @@ void ValidateWeightDataType(const TensorInfo& inputInfo,
const DataType inputType = inputInfo.GetDataType();
if (inputType == DataType::QAsymmU8)
{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
const std::vector<DataType> validTypes =
{
DataType::QAsymmU8,
- DataType::QuantizedSymm8PerAxis
+ DataType::QSymmS8,
+ DataType::QuantizedSymm8PerAxis // deprecated
};
+ ARMNN_NO_DEPRECATE_WARN_END
ValidateDataTypes(weightInfo, validTypes, descName);
}
@@ -412,7 +428,8 @@ void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
"but data type does not support per-axis quantization.") % descName % "weight"));
}
- ValidateTensorDataType(weightInfo, DataType::QuantizedSymm8PerAxis, descName, "weight");
+
+ ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");