diff options
author | Keith Davis <keith.davis@arm.com> | 2020-02-11 16:51:50 +0000 |
---|---|---|
committer | James Conroy <james.conroy@arm.com> | 2020-02-17 21:53:29 +0000 |
commit | 0c2eeac6347533a1d3d456aebea492f5123388f3 (patch) | |
tree | f218fc236137791c491b680dfd24fb9706c171a6 /src/backends/backendsCommon | |
parent | 4c3c1f486ab775eacb1f6455f8468f9be2c3e4f7 (diff) | |
download | armnn-0c2eeac6347533a1d3d456aebea492f5123388f3.tar.gz |
IVGCVSW-4436 Add ExecuteNetwork test for mobilenet_v2_int8
* Add QAsymmS8 to QueueDescriptor supportedTypes
* Add QSymmS8/QAsymmS8 to RefLayerSupport supportedTypes
* Some additional comments and refactoring
Change-Id: I8567314452e6e8f6f69cb6e458ee147d3fc92fab
Signed-off-by: Keith Davis <keith.davis@arm.com>
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 28 |
1 files changed, 17 insertions, 11 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index ebaf961fe8..fea72256a1 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -30,6 +30,8 @@ DataType GetBiasDataType(DataType inputDataType) return DataType::Float16; case DataType::Float32: return DataType::Float32; + case DataType::QAsymmS8: + return DataType::Signed32; case DataType::QAsymmU8: return DataType::Signed32; case DataType::QSymmS8: @@ -357,12 +359,13 @@ void ValidateWeightDataType(const TensorInfo& inputInfo, const std::string& descName) { const DataType inputType = inputInfo.GetDataType(); - if (inputType == DataType::QAsymmU8) + if (IsQuantized8BitType(inputType)) { ARMNN_NO_DEPRECATE_WARN_BEGIN const std::vector<DataType> validTypes = { DataType::QAsymmU8, + DataType::QAsymmS8, DataType::QSymmS8, DataType::QuantizedSymm8PerAxis // deprecated }; @@ -420,8 +423,7 @@ void ValidatePerAxisQuantization(const TensorInfo& inputInfo, const DataType inputDataType = inputInfo.GetDataType(); const DataType outputDataType = outputInfo.GetDataType(); - const bool canHavePerAxisQuantization = (inputDataType == DataType::QSymmS8 || - inputDataType == DataType::QAsymmU8) && inputDataType == outputDataType; + const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType; if (!canHavePerAxisQuantization) { @@ -599,6 +601,7 @@ void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -684,6 +687,7 @@ void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -1038,10 +1042,11 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { DataType::Float32, + DataType::Float16, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16, - DataType::QSymmS8, - DataType::Float16 + DataType::QSymmS8 }; ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName); @@ -1181,6 +1186,7 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co { DataType::Float32, DataType::QAsymmU8, + DataType::QAsymmS8, DataType::QSymmS16, DataType::QSymmS8, DataType::Float16 @@ -1255,6 +1261,7 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa { DataType::Float32, DataType::QAsymmU8, + DataType::QAsymmS8, DataType::QSymmS16, DataType::Float16 }; @@ -1309,6 +1316,7 @@ void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { DataType::Float32, DataType::Float16, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -1560,9 +1568,10 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::Float32, DataType::Float16, DataType::Signed32, + DataType::QSymmS16, + DataType::QAsymmS8, DataType::QAsymmU8, - DataType::QSymmS8, - DataType::QSymmS16 + DataType::QSymmS8 }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); @@ -2208,10 +2217,7 @@ void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); - if (outputTensorInfo.GetDataType() != DataType::QAsymmS8 && - outputTensorInfo.GetDataType() != DataType::QAsymmU8 && - outputTensorInfo.GetDataType() != DataType::QSymmS8 && - outputTensorInfo.GetDataType() != DataType::QSymmS16) + if (!IsQuantizedType(outputTensorInfo.GetDataType())) { throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type."); } |