aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2020-03-11 14:51:27 +0000
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2020-03-13 09:49:42 +0000
commit44179c372eea9f17c96cbf50ee383e57e14d70a6 (patch)
tree2a2971c2db67426107b21d9a045cfa46a4a1663a /src/backends/backendsCommon/WorkloadData.cpp
parente9b5d2989abc8008df7ff3ea287ee896ee1121a6 (diff)
downloadarmnn-44179c372eea9f17c96cbf50ee383e57e14d70a6.tar.gz
IVGCVSW-4511 Add BFloat16 to RefLayerSupport and unit tests
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: Ifaae4d5aac468ba927b2c6a4bf31b8c8522aeb2e
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp57
1 files changed, 51 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index bb0c21ffba..b501b3dbec 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -26,6 +26,8 @@ DataType GetBiasDataType(DataType inputDataType)
{
switch (inputDataType)
{
+ case DataType::BFloat16:
+ return DataType::BFloat16;
case DataType::Float16:
return DataType::Float16;
case DataType::Float32:
@@ -599,6 +601,7 @@ void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmS8,
@@ -628,6 +631,7 @@ void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedInputTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -685,6 +689,7 @@ void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmS8,
@@ -706,6 +711,7 @@ void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Boolean,
@@ -842,6 +848,7 @@ void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Boolean,
@@ -929,6 +936,7 @@ void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Boolean,
@@ -992,6 +1000,7 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmS8,
@@ -1016,6 +1025,7 @@ void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -1042,6 +1052,7 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmS8,
@@ -1077,6 +1088,7 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmU8,
DataType::QAsymmS8,
@@ -1111,6 +1123,7 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -1183,6 +1196,7 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmS8,
DataType::QAsymmU8,
@@ -1258,6 +1272,7 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmU8,
DataType::QAsymmS8,
@@ -1313,6 +1328,7 @@ void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmS8,
@@ -1339,6 +1355,7 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -1386,6 +1403,7 @@ void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmS8,
@@ -1460,6 +1478,7 @@ void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workload
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16
};
@@ -1488,6 +1507,7 @@ void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo)
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1512,6 +1532,7 @@ void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
};
@@ -1538,6 +1559,7 @@ void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Signed32,
@@ -1565,6 +1587,7 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Signed32,
@@ -1632,10 +1655,11 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
std::vector<DataType> supportedTypes =
{
- DataType::Float16,
- DataType::Float32,
- DataType::QAsymmU8,
- DataType::QSymmS16
+ DataType::BFloat16,
+ DataType::Float16,
+ DataType::Float32,
+ DataType::QAsymmU8,
+ DataType::QSymmS16
};
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
@@ -1657,6 +1681,7 @@ void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1705,6 +1730,7 @@ void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QSymmS16
@@ -1736,6 +1762,7 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QSymmS16
@@ -2051,7 +2078,8 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
DataType::Float32,
DataType::QAsymmU8,
DataType::QSymmS16,
- DataType::Float16
+ DataType::Float16,
+ DataType::BFloat16
};
ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
@@ -2082,7 +2110,8 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
DataType::Float32,
DataType::QAsymmU8,
DataType::QSymmS16,
- DataType::Float16
+ DataType::Float16,
+ DataType::BFloat16
};
ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
@@ -2110,6 +2139,7 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::Signed32,
@@ -2142,6 +2172,7 @@ void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -2206,6 +2237,7 @@ void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QSymmS8,
@@ -2234,6 +2266,7 @@ void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -2256,6 +2289,7 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2312,6 +2346,7 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::Signed32,
@@ -2401,6 +2436,7 @@ void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2429,6 +2465,7 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2475,6 +2512,7 @@ void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadI
const std::vector<DataType> supportedInputTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -2526,6 +2564,7 @@ void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16
};
@@ -2566,6 +2605,7 @@ void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmU8,
DataType::QSymmS16
@@ -2608,6 +2648,7 @@ void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2670,6 +2711,7 @@ void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -2894,6 +2936,7 @@ void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2974,6 +3017,7 @@ void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -3048,6 +3092,7 @@ void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo)
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,