diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 57 |
1 files changed, 51 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index bb0c21ffba..b501b3dbec 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -26,6 +26,8 @@ DataType GetBiasDataType(DataType inputDataType) { switch (inputDataType) { + case DataType::BFloat16: + return DataType::BFloat16; case DataType::Float16: return DataType::Float16; case DataType::Float32: @@ -599,6 +601,7 @@ void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmS8, @@ -628,6 +631,7 @@ void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedInputTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -685,6 +689,7 @@ void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmS8, @@ -706,6 +711,7 @@ void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::Boolean, @@ -842,6 +848,7 @@ void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::Boolean, @@ -929,6 +936,7 @@ void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::Boolean, @@ -992,6 +1000,7 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmS8, @@ -1016,6 +1025,7 @@ void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -1042,6 +1052,7 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmS8, @@ -1077,6 +1088,7 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::QAsymmU8, DataType::QAsymmS8, @@ -1111,6 +1123,7 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -1183,6 +1196,7 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::QAsymmS8, DataType::QAsymmU8, @@ -1258,6 +1272,7 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::QAsymmU8, DataType::QAsymmS8, @@ -1313,6 +1328,7 @@ void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmS8, @@ -1339,6 +1355,7 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -1386,6 +1403,7 @@ void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmS8, @@ -1460,6 +1478,7 @@ void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workload // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16 }; @@ -1488,6 +1507,7 @@ void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmU8, @@ -1512,6 +1532,7 @@ void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, }; @@ -1538,6 +1559,7 @@ void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::Signed32, @@ -1565,6 +1587,7 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::Signed32, @@ -1632,10 +1655,11 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c std::vector<DataType> supportedTypes = { - DataType::Float16, - DataType::Float32, - DataType::QAsymmU8, - DataType::QSymmS16 + DataType::BFloat16, + DataType::Float16, + DataType::Float32, + DataType::QAsymmU8, + DataType::QSymmS16 }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); @@ -1657,6 +1681,7 @@ void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmU8, @@ -1705,6 +1730,7 @@ void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QSymmS16 @@ -1736,6 +1762,7 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QSymmS16 @@ -2051,7 +2078,8 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::Float32, DataType::QAsymmU8, DataType::QSymmS16, - DataType::Float16 + DataType::Float16, + DataType::BFloat16 }; ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName); @@ -2082,7 +2110,8 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons DataType::Float32, DataType::QAsymmU8, DataType::QSymmS16, - DataType::Float16 + DataType::Float16, + DataType::BFloat16 }; ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName); @@ -2110,6 +2139,7 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::Signed32, @@ -2142,6 +2172,7 @@ void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmU8, @@ -2206,6 +2237,7 @@ void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QSymmS8, @@ -2234,6 +2266,7 @@ void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmU8, @@ -2256,6 +2289,7 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -2312,6 +2346,7 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::Signed32, @@ -2401,6 +2436,7 @@ void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -2429,6 +2465,7 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -2475,6 +2512,7 @@ void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadI const std::vector<DataType> supportedInputTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmU8, @@ -2526,6 +2564,7 @@ void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16 }; @@ -2566,6 +2605,7 @@ void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::QAsymmU8, DataType::QSymmS16 @@ -2608,6 +2648,7 @@ void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -2670,6 +2711,7 @@ void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmU8, @@ -2894,6 +2936,7 @@ void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -2974,6 +3017,7 @@ void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmU8, @@ -3048,6 +3092,7 @@ void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, |