aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp28
1 files changed, 17 insertions, 11 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index ebaf961fe8..fea72256a1 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -30,6 +30,8 @@ DataType GetBiasDataType(DataType inputDataType)
return DataType::Float16;
case DataType::Float32:
return DataType::Float32;
+ case DataType::QAsymmS8:
+ return DataType::Signed32;
case DataType::QAsymmU8:
return DataType::Signed32;
case DataType::QSymmS8:
@@ -357,12 +359,13 @@ void ValidateWeightDataType(const TensorInfo& inputInfo,
const std::string& descName)
{
const DataType inputType = inputInfo.GetDataType();
- if (inputType == DataType::QAsymmU8)
+ if (IsQuantized8BitType(inputType))
{
ARMNN_NO_DEPRECATE_WARN_BEGIN
const std::vector<DataType> validTypes =
{
DataType::QAsymmU8,
+ DataType::QAsymmS8,
DataType::QSymmS8,
DataType::QuantizedSymm8PerAxis // deprecated
};
@@ -420,8 +423,7 @@ void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
const DataType inputDataType = inputInfo.GetDataType();
const DataType outputDataType = outputInfo.GetDataType();
- const bool canHavePerAxisQuantization = (inputDataType == DataType::QSymmS8 ||
- inputDataType == DataType::QAsymmU8) && inputDataType == outputDataType;
+ const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
if (!canHavePerAxisQuantization)
{
@@ -599,6 +601,7 @@ void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
DataType::Float16,
DataType::Float32,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -684,6 +687,7 @@ void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
DataType::Float16,
DataType::Float32,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -1038,10 +1042,11 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
DataType::Float32,
+ DataType::Float16,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16,
- DataType::QSymmS8,
- DataType::Float16
+ DataType::QSymmS8
};
ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
@@ -1181,6 +1186,7 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
{
DataType::Float32,
DataType::QAsymmU8,
+ DataType::QAsymmS8,
DataType::QSymmS16,
DataType::QSymmS8,
DataType::Float16
@@ -1255,6 +1261,7 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa
{
DataType::Float32,
DataType::QAsymmU8,
+ DataType::QAsymmS8,
DataType::QSymmS16,
DataType::Float16
};
@@ -1309,6 +1316,7 @@ void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
DataType::Float32,
DataType::Float16,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -1560,9 +1568,10 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
DataType::Float32,
DataType::Float16,
DataType::Signed32,
+ DataType::QSymmS16,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
- DataType::QSymmS8,
- DataType::QSymmS16
+ DataType::QSymmS8
};
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
@@ -2208,10 +2217,7 @@ void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
- if (outputTensorInfo.GetDataType() != DataType::QAsymmS8 &&
- outputTensorInfo.GetDataType() != DataType::QAsymmU8 &&
- outputTensorInfo.GetDataType() != DataType::QSymmS8 &&
- outputTensorInfo.GetDataType() != DataType::QSymmS16)
+ if (!IsQuantizedType(outputTensorInfo.GetDataType()))
{
throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
}