diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2020-03-11 14:51:27 +0000 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2020-03-13 09:49:42 +0000 |
commit | 44179c372eea9f17c96cbf50ee383e57e14d70a6 (patch) | |
tree | 2a2971c2db67426107b21d9a045cfa46a4a1663a /src/backends/backendsCommon/WorkloadData.cpp | |
parent | e9b5d2989abc8008df7ff3ea287ee896ee1121a6 (diff) | |
download | armnn-44179c372eea9f17c96cbf50ee383e57e14d70a6.tar.gz |
IVGCVSW-4511 Add BFloat16 to RefLayerSupport and unit tests
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ifaae4d5aac468ba927b2c6a4bf31b8c8522aeb2e
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 57 |
1 files changed, 51 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index bb0c21ffba..b501b3dbec 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -26,6 +26,8 @@ DataType GetBiasDataType(DataType inputDataType) { switch (inputDataType) { + case DataType::BFloat16: + return DataType::BFloat16; case DataType::Float16: return DataType::Float16; case DataType::Float32: @@ -599,6 +601,7 @@ void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmS8, @@ -628,6 +631,7 @@ void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedInputTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -685,6 +689,7 @@ void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmS8, @@ -706,6 +711,7 @@ void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::Boolean, @@ -842,6 +848,7 @@ void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::Boolean, @@ -929,6 +936,7 @@ void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::Boolean, @@ -992,6 +1000,7 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmS8, @@ -1016,6 +1025,7 @@ void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -1042,6 +1052,7 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmS8, @@ -1077,6 +1088,7 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::QAsymmU8, DataType::QAsymmS8, @@ -1111,6 +1123,7 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -1183,6 +1196,7 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::QAsymmS8, DataType::QAsymmU8, @@ -1258,6 +1272,7 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::QAsymmU8, DataType::QAsymmS8, @@ -1313,6 +1328,7 @@ void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmS8, @@ -1339,6 +1355,7 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -1386,6 +1403,7 @@ void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmS8, @@ -1460,6 +1478,7 @@ void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workload // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16 }; @@ -1488,6 +1507,7 @@ void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmU8, @@ -1512,6 +1532,7 @@ void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, }; @@ -1538,6 +1559,7 @@ void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::Signed32, @@ -1565,6 +1587,7 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const // Check the supported data types std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::Signed32, @@ -1632,10 +1655,11 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c std::vector<DataType> supportedTypes = { - DataType::Float16, - DataType::Float32, - DataType::QAsymmU8, - DataType::QSymmS16 + DataType::BFloat16, + DataType::Float16, + DataType::Float32, + DataType::QAsymmU8, + DataType::QSymmS16 }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); @@ -1657,6 +1681,7 @@ void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmU8, @@ -1705,6 +1730,7 @@ void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QSymmS16 @@ -1736,6 +1762,7 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QSymmS16 @@ -2051,7 +2078,8 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::Float32, DataType::QAsymmU8, DataType::QSymmS16, - DataType::Float16 + DataType::Float16, + DataType::BFloat16 }; ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName); @@ -2082,7 +2110,8 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons DataType::Float32, DataType::QAsymmU8, DataType::QSymmS16, - DataType::Float16 + DataType::Float16, + DataType::BFloat16 }; ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName); @@ -2110,6 +2139,7 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::Signed32, @@ -2142,6 +2172,7 @@ void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmU8, @@ -2206,6 +2237,7 @@ void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QSymmS8, @@ -2234,6 +2266,7 @@ void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmU8, @@ -2256,6 +2289,7 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -2312,6 +2346,7 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::Signed32, @@ -2401,6 +2436,7 @@ void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -2429,6 +2465,7 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -2475,6 +2512,7 @@ void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadI const std::vector<DataType> supportedInputTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmU8, @@ -2526,6 +2564,7 @@ void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16 }; @@ -2566,6 +2605,7 @@ void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::QAsymmU8, DataType::QSymmS16 @@ -2608,6 +2648,7 @@ void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -2670,6 +2711,7 @@ void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmU8, @@ -2894,6 +2936,7 @@ void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, @@ -2974,6 +3017,7 @@ void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float32, DataType::Float16, DataType::QAsymmU8, @@ -3048,6 +3092,7 @@ void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) std::vector<DataType> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, |