diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 075884b2da..5057c8c4df 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -32,6 +32,8 @@ DataType GetBiasDataType(DataType inputDataType) return DataType::Float32; case DataType::QAsymmU8: return DataType::Signed32; + case DataType::QSymmS8: + return DataType::Signed32; case DataType::QSymmS16: return DataType::Signed32; default: @@ -418,8 +420,8 @@ void ValidatePerAxisQuantization(const TensorInfo& inputInfo, const DataType inputDataType = inputInfo.GetDataType(); const DataType outputDataType = outputInfo.GetDataType(); - const bool canHavePerAxisQuantization = - inputDataType == DataType::QAsymmU8 && inputDataType == outputDataType; + const bool canHavePerAxisQuantization = (inputDataType == DataType::QSymmS8 || + inputDataType == DataType::QAsymmU8) && inputDataType == outputDataType; if (!canHavePerAxisQuantization) { @@ -1038,6 +1040,7 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::Float32, DataType::QAsymmU8, DataType::QSymmS16, + DataType::QSymmS8, DataType::Float16 }; @@ -1071,6 +1074,7 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c { DataType::Float32, DataType::QAsymmU8, + DataType::QSymmS8, DataType::QSymmS16, DataType::Float16 }; @@ -1178,6 +1182,7 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co DataType::Float32, DataType::QAsymmU8, DataType::QSymmS16, + DataType::QSymmS8, DataType::Float16 }; @@ -1377,6 +1382,7 @@ void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::Float16, DataType::Float32, DataType::QAsymmU8, + DataType::QSymmS8, DataType::QSymmS16 }; @@ -1529,6 +1535,7 @@ void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::Float16, DataType::Signed32, DataType::QAsymmU8, + DataType::QSymmS8, DataType::QSymmS16 }; @@ -1554,6 +1561,7 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::Float16, DataType::Signed32, DataType::QAsymmU8, + DataType::QSymmS8, DataType::QSymmS16 }; @@ -2098,6 +2106,7 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::Float32, DataType::Signed32, DataType::QAsymmU8, + DataType::QSymmS8, DataType::QSymmS16 }; |