about | summary | refs | log | tree | commit | diff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
authorKeith Davis <keith.davis@arm.com>2020-02-11 16:51:50 +0000
committerJames Conroy <james.conroy@arm.com>2020-02-17 21:53:29 +0000
commit0c2eeac6347533a1d3d456aebea492f5123388f3 (patch)
treef218fc236137791c491b680dfd24fb9706c171a6 /src/backends/backendsCommon/WorkloadData.cpp
parent4c3c1f486ab775eacb1f6455f8468f9be2c3e4f7 (diff)
downloadarmnn-0c2eeac6347533a1d3d456aebea492f5123388f3.tar.gz
IVGCVSW-4436 Add ExecuteNetwork test for mobilenet_v2_int8
* Add QAsymmS8 to QueueDescriptor supportedTypes
* Add QSymmS8/QAsymmS8 to RefLayerSupport supportedTypes
* Some additional comments and refactoring

Change-Id: I8567314452e6e8f6f69cb6e458ee147d3fc92fab
Signed-off-by: Keith Davis <keith.davis@arm.com>
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp28
1 file changed, 17 insertions, 11 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index ebaf961fe8..fea72256a1 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -30,6 +30,8 @@ DataType GetBiasDataType(DataType inputDataType)
return DataType::Float16;
case DataType::Float32:
return DataType::Float32;
+ case DataType::QAsymmS8:
+ return DataType::Signed32;
case DataType::QAsymmU8:
return DataType::Signed32;
case DataType::QSymmS8:
@@ -357,12 +359,13 @@ void ValidateWeightDataType(const TensorInfo& inputInfo,
const std::string& descName)
{
const DataType inputType = inputInfo.GetDataType();
- if (inputType == DataType::QAsymmU8)
+ if (IsQuantized8BitType(inputType))
{
ARMNN_NO_DEPRECATE_WARN_BEGIN
const std::vector<DataType> validTypes =
{
DataType::QAsymmU8,
+ DataType::QAsymmS8,
DataType::QSymmS8,
DataType::QuantizedSymm8PerAxis // deprecated
};
@@ -420,8 +423,7 @@ void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
const DataType inputDataType = inputInfo.GetDataType();
const DataType outputDataType = outputInfo.GetDataType();
- const bool canHavePerAxisQuantization = (inputDataType == DataType::QSymmS8 ||
- inputDataType == DataType::QAsymmU8) && inputDataType == outputDataType;
+ const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
if (!canHavePerAxisQuantization)
{
@@ -599,6 +601,7 @@ void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
DataType::Float16,
DataType::Float32,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -684,6 +687,7 @@ void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
DataType::Float16,
DataType::Float32,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -1038,10 +1042,11 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
DataType::Float32,
+ DataType::Float16,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16,
- DataType::QSymmS8,
- DataType::Float16
+ DataType::QSymmS8
};
ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
@@ -1181,6 +1186,7 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
{
DataType::Float32,
DataType::QAsymmU8,
+ DataType::QAsymmS8,
DataType::QSymmS16,
DataType::QSymmS8,
DataType::Float16
@@ -1255,6 +1261,7 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa
{
DataType::Float32,
DataType::QAsymmU8,
+ DataType::QAsymmS8,
DataType::QSymmS16,
DataType::Float16
};
@@ -1309,6 +1316,7 @@ void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
DataType::Float32,
DataType::Float16,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -1560,9 +1568,10 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
DataType::Float32,
DataType::Float16,
DataType::Signed32,
+ DataType::QSymmS16,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
- DataType::QSymmS8,
- DataType::QSymmS16
+ DataType::QSymmS8
};
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
@@ -2208,10 +2217,7 @@ void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
- if (outputTensorInfo.GetDataType() != DataType::QAsymmS8 &&
- outputTensorInfo.GetDataType() != DataType::QAsymmU8 &&
- outputTensorInfo.GetDataType() != DataType::QSymmS8 &&
- outputTensorInfo.GetDataType() != DataType::QSymmS16)
+ if (!IsQuantizedType(outputTensorInfo.GetDataType()))
{
throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
}