-rw-r--r--  src/armnnUtils/BFloat16.hpp                                        |  16
-rw-r--r--  src/armnnUtils/QuantizeHelper.hpp                                  |  17
-rw-r--r--  src/backends/backendsCommon/MakeWorkloadHelper.hpp                 |   2
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp                       |  57
-rw-r--r--  src/backends/backendsCommon/WorkloadFactory.cpp                    |   6
-rw-r--r--  src/backends/backendsCommon/test/WorkloadTestUtils.hpp             |   1
-rw-r--r--  src/backends/backendsCommon/test/layerTests/ConcatTestImpl.cpp     |   7
-rw-r--r--  src/backends/backendsCommon/test/layerTests/ConcatTestImpl.hpp     |   7
-rw-r--r--  src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp     |  48
-rw-r--r--  src/backends/backendsCommon/test/layerTests/PadTestImpl.cpp        |  28
-rw-r--r--  src/backends/backendsCommon/test/layerTests/PadTestImpl.hpp        |  16
-rw-r--r--  src/backends/backendsCommon/test/layerTests/PermuteTestImpl.hpp    | 104
-rw-r--r--  src/backends/backendsCommon/test/layerTests/TransposeTestImpl.hpp  |  80
-rw-r--r--  src/backends/reference/RefLayerSupport.cpp                         | 211
-rw-r--r--  src/backends/reference/RefWorkloadFactory.cpp                      |  17
-rw-r--r--  src/backends/reference/test/RefLayerSupportTests.cpp               |   6
-rw-r--r--  src/backends/reference/test/RefLayerTests.cpp                      |  62
-rw-r--r--  src/backends/reference/workloads/Pad.cpp                           |   7
-rw-r--r--  src/backends/reference/workloads/RefPadWorkload.cpp                |   1
-rw-r--r--  src/backends/reference/workloads/RefPadWorkload.hpp                |   1
-rw-r--r--  src/backends/reference/workloads/RefPermuteWorkload.cpp            |   1
-rw-r--r--  src/backends/reference/workloads/RefPermuteWorkload.hpp            |   1
-rw-r--r--  src/backends/reference/workloads/RefTransposeWorkload.cpp          |   1
-rw-r--r--  src/backends/reference/workloads/RefTransposeWorkload.hpp          |   1
24 files changed, 530 insertions(+), 168 deletions(-)
diff --git a/src/armnnUtils/BFloat16.hpp b/src/armnnUtils/BFloat16.hpp
index 5da4da559f..16ceb524c3 100644
--- a/src/armnnUtils/BFloat16.hpp
+++ b/src/armnnUtils/BFloat16.hpp
@@ -27,6 +27,17 @@ public:
m_Value = Float32ToBFloat16(v).Val();
}
+ operator float() const
+ {
+ return ToFloat32();
+ }
+
+ BFloat16& operator=(const BFloat16& other)
+ {
+ m_Value = other.Val();
+ return *this;
+ }
+
BFloat16& operator=(float v)
{
m_Value = Float32ToBFloat16(v).Val();
@@ -38,11 +49,6 @@ public:
return m_Value == r.Val();
}
- bool operator==(const float& r) const
- {
- return ToFloat32() == r;
- }
-
static BFloat16 Float32ToBFloat16(const float v)
{
if (std::isnan(v))
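The new implicit conversion to float is what makes the removed operator==(const float&) redundant: a BFloat16 now participates in plain float arithmetic and comparison through operator float(). A minimal usage sketch of the members shown above (variable names here are illustrative only):

    armnn::BFloat16 a(1.5f);                   // constructor truncates float -> bfloat16
    a = 2.5f;                                  // operator=(float)
    float f = a;                               // new operator float(), i.e. ToFloat32()
    bool same  = (a == armnn::BFloat16(2.5f)); // BFloat16 == BFloat16
    bool samef = (f == 2.5f);                  // comparison now happens in float,
                                               // replacing the deleted operator==(const float&)
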
diff --git a/src/armnnUtils/QuantizeHelper.hpp b/src/armnnUtils/QuantizeHelper.hpp
index 6fd13fda98..596ec98f64 100644
--- a/src/armnnUtils/QuantizeHelper.hpp
+++ b/src/armnnUtils/QuantizeHelper.hpp
@@ -8,6 +8,7 @@
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/TypesUtils.hpp>
+#include <BFloat16.hpp>
#include <Half.hpp>
#include <initializer_list>
@@ -65,6 +66,22 @@ struct SelectiveQuantizer<armnn::Half, false>
}
};
+template<>
+struct SelectiveQuantizer<armnn::BFloat16, false>
+{
+ static armnn::BFloat16 Quantize(float value, float scale, int32_t offset)
+ {
+ armnn::IgnoreUnused(scale, offset);
+ return armnn::BFloat16(value);
+ }
+
+ static float Dequantize(armnn::BFloat16 value, float scale, int32_t offset)
+ {
+ armnn::IgnoreUnused(scale, offset);
+ return value;
+ }
+};
+
template<typename T>
T SelectiveQuantize(float value, float scale, int32_t offset)
{
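For non-quantized element types SelectiveQuantizer is a pass-through, and the new BFloat16 specialization mirrors the existing Half one: scale and offset are ignored, the value is truncated to bfloat16 on the way in and widened back to float on the way out. A round-trip sketch, assuming these helpers sit in the armnnUtils namespace as the file location suggests:

    float original = 3.14159f;
    // Quantize ignores the 0.5f/5 quantization parameters for BFloat16.
    armnn::BFloat16 stored = armnnUtils::SelectiveQuantize<armnn::BFloat16>(original, 0.5f, 5);
    // Dequantize returns the bfloat16-truncated value, not necessarily 'original'.
    float back = armnnUtils::SelectiveQuantizer<armnn::BFloat16, false>::Dequantize(stored, 0.5f, 5);
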
diff --git a/src/backends/backendsCommon/MakeWorkloadHelper.hpp b/src/backends/backendsCommon/MakeWorkloadHelper.hpp
index 7ef140e453..8abc8a6ef5 100644
--- a/src/backends/backendsCommon/MakeWorkloadHelper.hpp
+++ b/src/backends/backendsCommon/MakeWorkloadHelper.hpp
@@ -52,6 +52,7 @@ std::unique_ptr<IWorkload> MakeWorkloadHelper(const QueueDescriptorType& descrip
switch (dataType)
{
+
case DataType::Float16:
return MakeWorkloadForType<Float16Workload>::Func(descriptor, info, std::forward<Args>(args)...);
case DataType::Float32:
@@ -65,6 +66,7 @@ std::unique_ptr<IWorkload> MakeWorkloadHelper(const QueueDescriptorType& descrip
return MakeWorkloadForType<Int32Workload>::Func(descriptor, info, std::forward<Args>(args)...);
case DataType::Boolean:
return MakeWorkloadForType<BooleanWorkload>::Func(descriptor, info, std::forward<Args>(args)...);
+ case DataType::BFloat16:
case DataType::QSymmS16:
return nullptr;
default:
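Note that DataType::BFloat16 is grouped with the QSymmS16 case, so the generic helper returns nullptr for BFloat16 rather than building a workload. A backend that does support BFloat16 must therefore create those workloads explicitly before deferring to the helper, which is exactly the pattern RefWorkloadFactory adopts further down; condensed from the CreatePad change below:

    std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
                                                             const WorkloadInfo& info) const
    {
        if (IsBFloat16(info))   // new predicate added in RefWorkloadFactory.cpp
        {
            return std::make_unique<RefPadBFloat16Workload>(descriptor, info);
        }
        // Everything else still goes through the generic dispatch.
        return MakeWorkload<RefPadFloat32Workload, RefPadQAsymm8Workload>(descriptor, info);
    }
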
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index bb0c21ffba..b501b3dbec 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -26,6 +26,8 @@ DataType GetBiasDataType(DataType inputDataType)
{
switch (inputDataType)
{
+ case DataType::BFloat16:
+ return DataType::BFloat16;
case DataType::Float16:
return DataType::Float16;
case DataType::Float32:
@@ -599,6 +601,7 @@ void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmS8,
@@ -628,6 +631,7 @@ void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedInputTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -685,6 +689,7 @@ void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmS8,
@@ -706,6 +711,7 @@ void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Boolean,
@@ -842,6 +848,7 @@ void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Boolean,
@@ -929,6 +936,7 @@ void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Boolean,
@@ -992,6 +1000,7 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmS8,
@@ -1016,6 +1025,7 @@ void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -1042,6 +1052,7 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmS8,
@@ -1077,6 +1088,7 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmU8,
DataType::QAsymmS8,
@@ -1111,6 +1123,7 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -1183,6 +1196,7 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmS8,
DataType::QAsymmU8,
@@ -1258,6 +1272,7 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmU8,
DataType::QAsymmS8,
@@ -1313,6 +1328,7 @@ void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmS8,
@@ -1339,6 +1355,7 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -1386,6 +1403,7 @@ void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmS8,
@@ -1460,6 +1478,7 @@ void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workload
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16
};
@@ -1488,6 +1507,7 @@ void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo)
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1512,6 +1532,7 @@ void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
};
@@ -1538,6 +1559,7 @@ void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Signed32,
@@ -1565,6 +1587,7 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Signed32,
@@ -1632,10 +1655,11 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
std::vector<DataType> supportedTypes =
{
- DataType::Float16,
- DataType::Float32,
- DataType::QAsymmU8,
- DataType::QSymmS16
+ DataType::BFloat16,
+ DataType::Float16,
+ DataType::Float32,
+ DataType::QAsymmU8,
+ DataType::QSymmS16
};
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
@@ -1657,6 +1681,7 @@ void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1705,6 +1730,7 @@ void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QSymmS16
@@ -1736,6 +1762,7 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QSymmS16
@@ -2051,7 +2078,8 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
DataType::Float32,
DataType::QAsymmU8,
DataType::QSymmS16,
- DataType::Float16
+ DataType::Float16,
+ DataType::BFloat16
};
ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
@@ -2082,7 +2110,8 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
DataType::Float32,
DataType::QAsymmU8,
DataType::QSymmS16,
- DataType::Float16
+ DataType::Float16,
+ DataType::BFloat16
};
ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
@@ -2110,6 +2139,7 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::Signed32,
@@ -2142,6 +2172,7 @@ void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -2206,6 +2237,7 @@ void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QSymmS8,
@@ -2234,6 +2266,7 @@ void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -2256,6 +2289,7 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2312,6 +2346,7 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::Signed32,
@@ -2401,6 +2436,7 @@ void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2429,6 +2465,7 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2475,6 +2512,7 @@ void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadI
const std::vector<DataType> supportedInputTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -2526,6 +2564,7 @@ void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16
};
@@ -2566,6 +2605,7 @@ void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmU8,
DataType::QSymmS16
@@ -2608,6 +2648,7 @@ void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2670,6 +2711,7 @@ void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -2894,6 +2936,7 @@ void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2974,6 +3017,7 @@ void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -3048,6 +3092,7 @@ void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo)
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
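Every Validate() touched above follows the same shape: enumerate the supported DataTypes in a vector, then run each tensor through ValidateDataTypes, which throws when the type is missing from the list. A hedged sketch of that membership check (the real helper lives in WorkloadData.cpp; its exact signature and message may differ):

    #include <algorithm>
    #include <string>
    #include <vector>
    #include <armnn/Exceptions.hpp>
    #include <armnn/Tensor.hpp>

    void ValidateDataTypesSketch(const armnn::TensorInfo& info,
                                 const std::vector<armnn::DataType>& supportedTypes,
                                 const std::string& descName)
    {
        const bool found = std::find(supportedTypes.begin(), supportedTypes.end(),
                                     info.GetDataType()) != supportedTypes.end();
        if (!found)
        {
            throw armnn::InvalidArgumentException(descName + ": tensor type is not supported.");
        }
    }
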
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 6ac76ecea6..2e1ce0a674 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -325,6 +325,7 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
TensorInfo biasInfo;
const TensorInfo * biasInfoPtr = nullptr;
+ static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
@@ -341,6 +342,11 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
// If biases are not enabled pass a dummy tensorinfo for the validation
switch(input.GetDataType())
{
+ case DataType::BFloat16:
+ {
+ biasInfoPtr = &dummyBFloat16Bias;
+ break;
+ }
case DataType::Float16:
{
biasInfoPtr = &dummyFloat16Bias;
diff --git a/src/backends/backendsCommon/test/WorkloadTestUtils.hpp b/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
index 51683335e1..df001b7530 100644
--- a/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
+++ b/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
@@ -95,6 +95,7 @@ inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Option
switch(weightsType.value())
{
+ case armnn::DataType::BFloat16:
case armnn::DataType::Float16:
case armnn::DataType::Float32:
return weightsType;
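The rule this encodes: biases follow the weights type for the float family (now including BFloat16), while quantized weights take Signed32 biases; GetBiasDataType at the top of WorkloadData.cpp applies the same mapping at validation time. A sketch of the full mapping — the float cases are visible in the hunks, the quantized branch is an assumption about the elided lines:

    armnn::DataType BiasTypeForWeightsSketch(armnn::DataType weightsType)
    {
        switch (weightsType)
        {
            case armnn::DataType::BFloat16:
            case armnn::DataType::Float16:
            case armnn::DataType::Float32:
                return weightsType;                 // bias shares the float type
            default:
                return armnn::DataType::Signed32;   // assumed: quantized weights -> int32 bias
        }
    }
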
diff --git a/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.cpp
index f6f4b09f6a..1e40b42dcf 100644
--- a/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.cpp
@@ -2342,6 +2342,13 @@ LayerTestResult<Half, 3> ConcatFloat16Test(
return Concat3dDim1TestImpl<DataType::Float16>(workloadFactory, memoryManager, 0.0f, 0);
}
+LayerTestResult<BFloat16, 3> ConcatBFloat16Test(
+ IWorkloadFactory& workloadFactory,
+ const IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ return Concat3dDim1TestImpl<DataType::BFloat16>(workloadFactory, memoryManager, 0.0f, 0);
+}
+
LayerTestResult<uint8_t, 3> ConcatUint8DifferentQParamsTest(
IWorkloadFactory& workloadFactory,
const IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
diff --git a/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.hpp
index 4ce9d2921f..167a547542 100644
--- a/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.hpp
@@ -7,8 +7,9 @@
#include "LayerTestResult.hpp"
-#include <ResolveType.hpp>
+#include <BFloat16.hpp>
#include <Half.hpp>
+#include <ResolveType.hpp>
#include <armnn/backends/IBackendInternal.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
@@ -23,6 +24,10 @@ LayerTestResult<float, 3> ConcatTest(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+LayerTestResult<armnn::BFloat16, 3> ConcatBFloat16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
LayerTestResult<armnn::Half, 3> ConcatFloat16Test(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
diff --git a/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
index 89cdd96e37..e1babd388b 100644
--- a/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
@@ -2791,6 +2791,12 @@ LayerTestResult<T, 4> CompareDepthwiseConvolution2dTestImpl(
//
// Explicit template specializations
//
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+Convolution2d3x3Dilation3x3Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory&,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr&,
+ bool,
+ armnn::DataLayout);
template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
Convolution2d3x3Dilation3x3Test<armnn::DataType::Float32, armnn::DataType::Float32>(
@@ -2820,6 +2826,13 @@ Convolution2d2x3x3Dilation3x3Test<armnn::DataType::Float32, armnn::DataType::Flo
bool,
armnn::DataLayout);
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+Convolution2d2x3x3Dilation3x3Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory&,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr&,
+ bool,
+ armnn::DataLayout);
+
template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 4>
Convolution2d2x3x3Dilation3x3Test<armnn::DataType::QAsymmU8, armnn::DataType::Signed32>(
armnn::IWorkloadFactory&,
@@ -2834,6 +2847,13 @@ Convolution2d2x3x3Dilation3x3Test<armnn::DataType::QSymmS16, armnn::DataType::Si
bool,
armnn::DataLayout);
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+Convolution2d2x2Dilation2x2Padding2x2Stride3x3Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory &workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager,
+ bool biasEnabled,
+ const armnn::DataLayout layout);
+
template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
Convolution2d2x2Dilation2x2Padding2x2Stride3x3Test<armnn::DataType::Float32, armnn::DataType::Float32>(
armnn::IWorkloadFactory &workloadFactory,
@@ -2855,6 +2875,13 @@ Convolution2d2x2Dilation2x2Padding2x2Stride3x3Test<armnn::DataType::QSymmS16, ar
bool biasEnabled,
const armnn::DataLayout layout);
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+DepthwiseConvolution2d3x3Dilation3x3Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory&,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr&,
+ bool,
+ armnn::DataLayout);
+
template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
DepthwiseConvolution2d3x3Dilation3x3Test<armnn::DataType::Float32, armnn::DataType::Float32>(
armnn::IWorkloadFactory&,
@@ -2876,6 +2903,13 @@ DepthwiseConvolution2d3x3Dilation3x3Test<armnn::DataType::QSymmS16, armnn::DataT
bool,
armnn::DataLayout);
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+DepthwiseConvolution2d2x3x3Dilation3x3Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory&,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr&,
+ bool,
+ armnn::DataLayout);
+
template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
DepthwiseConvolution2d2x3x3Dilation3x3Test<armnn::DataType::Float32, armnn::DataType::Float32>(
armnn::IWorkloadFactory&,
@@ -2897,6 +2931,13 @@ DepthwiseConvolution2d2x3x3Dilation3x3Test<armnn::DataType::QSymmS16, armnn::Dat
bool,
armnn::DataLayout);
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+DepthwiseConvolution2dMult4Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory &workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager,
+ bool biasEnabled,
+ const armnn::DataLayout layout);
+
template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
DepthwiseConvolution2dMult4Test<armnn::DataType::Float32, armnn::DataType::Float32>(
armnn::IWorkloadFactory &workloadFactory,
@@ -2904,6 +2945,13 @@ DepthwiseConvolution2dMult4Test<armnn::DataType::Float32, armnn::DataType::Float
bool biasEnabled,
const armnn::DataLayout layout);
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+DepthwiseConvolution2dMult2Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory &workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager,
+ bool biasEnabled,
+ const armnn::DataLayout layout);
+
template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
DepthwiseConvolution2dMult2Test<armnn::DataType::Float32, armnn::DataType::Float32>(
armnn::IWorkloadFactory &workloadFactory,
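These additions are explicit template instantiations: the Conv2d/DepthwiseConv2d test templates are defined in this .cpp file, so every <ArmnnType, ArmnnBType> combination a backend test references must be instantiated here, or the link step fails with an undefined symbol. The new BFloat16 instantiations are what allow the RefLayerTests.cpp cases below to compile. The mechanism in miniature:

    // header: declaration only
    template <typename T> T Twice(T value);

    // source file: definition plus explicit instantiation
    template <typename T> T Twice(T value) { return value + value; }
    template float Twice<float>(float);   // without this line, a use of Twice<float>
                                          // in another translation unit fails to link
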
diff --git a/src/backends/backendsCommon/test/layerTests/PadTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/PadTestImpl.cpp
index 69c651b5cd..120572ce29 100644
--- a/src/backends/backendsCommon/test/layerTests/PadTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/PadTestImpl.cpp
@@ -497,3 +497,31 @@ LayerTestResult<float, 4> PadFloat324dTest(
{
return Pad4dTestCommon<armnn::DataType::Float32>(workloadFactory, memoryManager, 0.0f, 0);
}
+
+LayerTestResult<armnn::BFloat16, 2> PadBFloat162dTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ return Pad2dTestCommon<armnn::DataType::BFloat16>(workloadFactory, memoryManager, 0.0f, 0);
+}
+
+LayerTestResult<armnn::BFloat16, 2> PadBFloat162dCustomPaddingTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ return Pad2dTestCommon<armnn::DataType::BFloat16>(workloadFactory, memoryManager, 0.0f, 0, 1.0f);
+}
+
+LayerTestResult<armnn::BFloat16, 3> PadBFloat163dTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ return Pad3dTestCommon<armnn::DataType::BFloat16>(workloadFactory, memoryManager, 0.0f, 0);
+}
+
+LayerTestResult<armnn::BFloat16, 4> PadBFloat164dTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ return Pad4dTestCommon<armnn::DataType::BFloat16>(workloadFactory, memoryManager, 0.0f, 0);
+}
diff --git a/src/backends/backendsCommon/test/layerTests/PadTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/PadTestImpl.hpp
index bc514881d6..34aa6c66a3 100644
--- a/src/backends/backendsCommon/test/layerTests/PadTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/PadTestImpl.hpp
@@ -67,3 +67,19 @@ LayerTestResult<float, 3> PadFloat323dTest(
LayerTestResult<float, 4> PadFloat324dTest(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<armnn::BFloat16, 2> PadBFloat162dTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<armnn::BFloat16, 2> PadBFloat162dCustomPaddingTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<armnn::BFloat16, 3> PadBFloat163dTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<armnn::BFloat16, 4> PadBFloat164dTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
diff --git a/src/backends/backendsCommon/test/layerTests/PermuteTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/PermuteTestImpl.hpp
index 71e15334e7..96d4ec8f0f 100644
--- a/src/backends/backendsCommon/test/layerTests/PermuteTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/PermuteTestImpl.hpp
@@ -72,27 +72,31 @@ LayerTestResult<T, 4> SimplePermuteTest(
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
{
1, 2,
3, 4,
5, 6,
7, 8
- });
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
{
1, 5, 2, 6,
3, 7, 4, 8
- });
+ },
+ qScale, qOffset);
return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
@@ -117,28 +121,32 @@ LayerTestResult<T, 4> PermuteValueSet1Test(
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
{
1, 2, 3,
11, 12, 13,
21, 22, 23,
31, 32, 33
- });
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
{
1, 11, 21, 31,
2, 12, 22, 32,
3, 13, 23, 33
- });
+ },
+ qScale, qOffset);
return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
@@ -163,28 +171,32 @@ LayerTestResult<T, 4> PermuteValueSet2Test(
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
{
1, 11, 21, 31,
2, 12, 22, 32,
3, 13, 23, 33
- });
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
{
1, 2, 3,
11, 12, 13,
21, 22, 23,
31, 32, 33,
- });
+ },
+ qScale, qOffset);
return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
@@ -209,30 +221,34 @@ LayerTestResult<T, 4> PermuteValueSet3Test(
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
- {
- 1, 2, 3,
- 11, 12, 13,
- 21, 22, 23,
- 31, 32, 33,
- 41, 42, 43,
- 51, 52, 53
- });
-
- std::vector<T> outputExpected = std::vector<T>(
- {
- 1, 11, 21, 31, 41, 51,
- 2, 12, 22, 32, 42, 52,
- 3, 13, 23, 33, 43, 53
- });
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
+ {
+ 1, 2, 3,
+ 11, 12, 13,
+ 21, 22, 23,
+ 31, 32, 33,
+ 41, 42, 43,
+ 51, 52, 53
+ },
+ qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
+ {
+ 1, 11, 21, 31, 41, 51,
+ 2, 12, 22, 32, 42, 52,
+ 3, 13, 23, 33, 43, 53
+ },
+ qScale, qOffset);
return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
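The substantive change in these test helpers is swapping raw std::vector<T> literals for armnnUtils::QuantizedVector, so a single set of float reference values serves every element type: quantized T gets scaled and offset, while Float32, Half and (through the new specialization in QuantizeHelper.hpp) BFloat16 pass through untouched. A hedged sketch of what QuantizedVector plausibly does in terms of SelectiveQuantize; the real helper's signature may differ:

    #include <cstdint>
    #include <vector>

    template <typename T>
    std::vector<T> QuantizedVectorSketch(const std::vector<float>& values,
                                         float qScale, int32_t qOffset)
    {
        std::vector<T> result;
        result.reserve(values.size());
        for (float v : values)
        {
            // Pass-through for float-family T, real quantization otherwise.
            result.push_back(armnnUtils::SelectiveQuantize<T>(v, qScale, qOffset));
        }
        return result;
    }
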
diff --git a/src/backends/backendsCommon/test/layerTests/TransposeTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/TransposeTestImpl.hpp
index 0e0f317a3e..5721952066 100644
--- a/src/backends/backendsCommon/test/layerTests/TransposeTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/TransposeTestImpl.hpp
@@ -72,27 +72,31 @@ LayerTestResult<T, 4> SimpleTransposeTest(
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
{
1, 2,
3, 4,
5, 6,
7, 8
- });
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
{
1, 5, 2, 6,
3, 7, 4, 8
- });
+ },
+ qScale, qOffset);
return SimpleTransposeTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
@@ -117,28 +121,32 @@ LayerTestResult<T, 4> TransposeValueSet1Test(
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
{
1, 2, 3,
11, 12, 13,
21, 22, 23,
31, 32, 33
- });
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
{
1, 11, 21, 31,
2, 12, 22, 32,
3, 13, 23, 33
- });
+ },
+ qScale, qOffset);
return SimpleTransposeTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
@@ -163,28 +171,32 @@ LayerTestResult<T, 4> TransposeValueSet2Test(
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
{
1, 11, 21, 31,
2, 12, 22, 32,
3, 13, 23, 33
- });
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
{
1, 2, 3,
11, 12, 13,
21, 22, 23,
31, 32, 33,
- });
+ },
+ qScale, qOffset);
return SimpleTransposeTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
@@ -209,15 +221,17 @@ LayerTestResult<T, 4> TransposeValueSet3Test(
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
{
1, 2, 3,
11, 12, 13,
@@ -225,14 +239,16 @@ LayerTestResult<T, 4> TransposeValueSet3Test(
31, 32, 33,
41, 42, 43,
51, 52, 53
- });
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
{
1, 11, 21, 31, 41, 51,
2, 12, 22, 32, 42, 52,
3, 13, 23, 33, 43, 53
- });
+ },
+ qScale, qOffset);
return SimpleTransposeTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index cb94955e7a..9dc576cac8 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -79,6 +79,7 @@ bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
// Define supported types.
std::array<DataType,6> supportedTypes = {
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmS8,
@@ -145,6 +146,7 @@ bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
bool supported = true;
std::array<DataType,6> supportedTypes = {
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmS8,
@@ -179,8 +181,9 @@ bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const
{
IgnoreUnused(descriptor);
- std::array<DataType, 4> supportedTypes =
+ std::array<DataType, 5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmU8,
DataType::QSymmS16,
@@ -208,8 +211,9 @@ bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
{
IgnoreUnused(descriptor);
- std::array<DataType, 4> supportedTypes =
+ std::array<DataType, 5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -256,12 +260,13 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
std::string outputTensorStr = "output";
// Define supported types.
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,5> supportedTypes =
{
- DataType::Float32,
- DataType::Float16,
- DataType::QAsymmU8,
- DataType::QSymmS16
+ DataType::BFloat16,
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QAsymmU8,
+ DataType::QSymmS16
};
supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
@@ -298,8 +303,9 @@ bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
{
IgnoreUnused(descriptor);
- std::array<DataType, 4> supportedInputTypes =
+ std::array<DataType, 5> supportedInputTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -327,13 +333,14 @@ bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inp
IgnoreUnused(descriptor);
bool supported = true;
- std::array<DataType,5> supportedTypes =
+ std::array<DataType,6> supportedTypes =
{
- DataType::Float32,
- DataType::Float16,
- DataType::QAsymmU8,
- DataType::QAsymmS8,
- DataType::QSymmS16
+ DataType::BFloat16,
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QAsymmU8,
+ DataType::QAsymmS8,
+ DataType::QSymmS16
};
supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
@@ -354,8 +361,9 @@ bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inp
bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- std::array<DataType,6> supportedTypes =
+ std::array<DataType,7> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Signed32,
DataType::QAsymmU8,
@@ -418,8 +426,9 @@ bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
bool supported = true;
// Define supported types.
- std::array<DataType,6> supportedTypes =
+ std::array<DataType,7> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -464,8 +473,9 @@ bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
if (biases.has_value())
{
- std::array<DataType,3> biasesSupportedTypes =
+ std::array<DataType,4> biasesSupportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Signed32
@@ -516,8 +526,9 @@ bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
IgnoreUnused(descriptor);
bool supported = true;
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -546,8 +557,9 @@ bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
bool supported = true;
// Define supported types.
- std::array<DataType,6> supportedTypes =
+ std::array<DataType,7> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QSymmS8,
@@ -592,8 +604,9 @@ bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
if (biases.has_value())
{
- std::array<DataType,3> biasesSupportedTypes =
+ std::array<DataType,4> biasesSupportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Signed32
@@ -629,7 +642,8 @@ bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
"Reference dequantize: per-axis quantized input not support .");
- std::array<DataType,2> supportedOutputTypes = {
+ std::array<DataType,3> supportedOutputTypes = {
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16
};
@@ -658,8 +672,9 @@ bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncod
bool supported = true;
- std::array<DataType,3> supportedInputTypes =
+ std::array<DataType,4> supportedInputTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmU8,
DataType::QSymmS16
@@ -691,7 +706,8 @@ bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
{
bool supported = true;
- std::array<DataType,4> supportedTypes = {
+ std::array<DataType,5> supportedTypes = {
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -726,8 +742,9 @@ bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
{
IgnoreUnused(descriptor);
- std::array<DataType, 4> supportedTypes =
+ std::array<DataType, 5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -789,8 +806,9 @@ bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
IgnoreUnused(output);
bool supported = true;
- std::array<DataType,3> supportedTypes =
+ std::array<DataType,4> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QSymmS16
@@ -815,13 +833,14 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
bool supported = true;
// Define supported types.
- std::array<DataType,5> supportedTypes =
+ std::array<DataType,6> supportedTypes =
{
- DataType::Float32,
- DataType::Float16,
- DataType::QAsymmU8,
- DataType::QAsymmS8,
- DataType::QSymmS16
+ DataType::BFloat16,
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QAsymmU8,
+ DataType::QAsymmS8,
+ DataType::QSymmS16
};
supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
@@ -863,9 +882,10 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
if (descriptor.m_BiasEnabled)
{
// Defined supported types for bias
- std::array<DataType, 3>
+ std::array<DataType, 4>
supportedBiasTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Signed32
@@ -891,8 +911,9 @@ bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
armnn::Optional<std::string&> reasonIfUnsupported) const
{
bool supported = true;
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -939,8 +960,9 @@ bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
{
IgnoreUnused(descriptor);
// Define supported types
- std::array<DataType, 4> supportedTypes =
+ std::array<DataType, 3> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16
};
@@ -970,8 +992,9 @@ bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
{
IgnoreUnused(descriptor);
// Define supported types
- std::array<DataType, 4> supportedTypes =
+ std::array<DataType, 5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1003,10 +1026,11 @@ bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
{
IgnoreUnused(descriptor);
- std::array<DataType, 2> supportedTypes =
+ std::array<DataType, 3> supportedTypes =
{
- DataType::Float32,
- DataType::Float16
+ DataType::BFloat16,
+ DataType::Float32,
+ DataType::Float16
};
bool supported = true;
@@ -1038,7 +1062,8 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
bool supported = true;
- std::array<DataType,2> supportedTypes = {
+ std::array<DataType,3> supportedTypes = {
+ DataType::BFloat16,
DataType::Float32,
DataType::QSymmS16
};
@@ -1139,7 +1164,8 @@ bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
{
bool supported = true;
- std::array<DataType,5> supportedTypes = {
+ std::array<DataType,6> supportedTypes = {
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmS8,
@@ -1177,8 +1203,9 @@ bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
std::string meanLayerStr = "Mean";
std::string outputTensorStr = "output";
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1243,8 +1270,9 @@ bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
{
bool supported = true;
- std::array<DataType,5> supportedTypes =
+ std::array<DataType,6> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1271,7 +1299,8 @@ bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
{
bool supported = true;
- std::array<DataType,4> supportedTypes = {
+ std::array<DataType,5> supportedTypes = {
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1307,6 +1336,7 @@ bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
bool supported = true;
std::array<DataType,6> supportedTypes = {
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1343,8 +1373,9 @@ bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
IgnoreUnused(descriptor);
// Define supported types
- std::array<DataType, 4> supportedTypes =
+ std::array<DataType, 5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -1381,8 +1412,9 @@ bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
bool supported = true;
// Define supported output and inputs types.
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1410,8 +1442,9 @@ bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
bool supported = true;
// Define supported output and inputs types.
- std::array<DataType, 4> supportedTypes =
+ std::array<DataType, 5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1439,8 +1472,9 @@ bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
bool supported = true;
// Define supported output and inputs types.
- std::array<DataType,5> supportedTypes =
+ std::array<DataType,6> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmS8,
@@ -1467,7 +1501,8 @@ bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
bool supported = true;
// Define supported input types.
- std::array<DataType,6> supportedInputTypes = {
+ std::array<DataType,7> supportedInputTypes = {
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmS8,
@@ -1505,6 +1540,7 @@ bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
// Define supported output types.
std::array<DataType,7> supportedOutputTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Signed32,
@@ -1522,8 +1558,9 @@ bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Optional<std::string&> reasonIfUnsupported) const
{
bool supported = true;
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1549,8 +1586,9 @@ bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
{
IgnoreUnused(descriptor);
bool supported = true;
- std::array<DataType,5> supportedTypes =
+ std::array<DataType,6> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1588,8 +1626,9 @@ bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
IgnoreUnused(descriptor);
bool supported = true;
- std::array<DataType, 3> supportedTypes =
+ std::array<DataType, 4> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmU8,
DataType::QSymmS16
@@ -1614,14 +1653,15 @@ bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
{
IgnoreUnused(descriptor);
bool supported = true;
- std::array<DataType,6> supportedTypes =
+ std::array<DataType,7> supportedTypes =
{
- DataType::Float32,
- DataType::Float16,
- DataType::QSymmS8,
- DataType::QAsymmS8,
- DataType::QAsymmU8,
- DataType::QSymmS16
+ DataType::BFloat16,
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QSymmS8,
+ DataType::QAsymmS8,
+ DataType::QAsymmU8,
+ DataType::QSymmS16
};
supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
@@ -1643,12 +1683,13 @@ bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
{
IgnoreUnused(descriptor);
bool supported = true;
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,5> supportedTypes =
{
- DataType::Float32,
- DataType::Float16,
- DataType::QAsymmU8,
- DataType::QSymmS16
+ DataType::BFloat16,
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QAsymmU8,
+ DataType::QSymmS16
};
supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
@@ -1672,8 +1713,9 @@ bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
IgnoreUnused(descriptor);
bool supported = true;
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1698,8 +1740,9 @@ bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
{
IgnoreUnused(descriptor);
bool supported = true;
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1719,8 +1762,9 @@ bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
{
IgnoreUnused(descriptor);
bool supported = true;
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1749,8 +1793,9 @@ bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inp
IgnoreUnused(descriptor);
bool supported = true;
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1780,8 +1825,9 @@ bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
IgnoreUnused(descriptor);
bool supported = true;
- std::array<DataType,3> supportedTypes =
+ std::array<DataType,4> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmU8,
DataType::QSymmS16
@@ -1806,7 +1852,8 @@ bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
{
bool supported = true;
- std::array<DataType,4> supportedTypes = {
+ std::array<DataType,5> supportedTypes = {
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1841,8 +1888,9 @@ bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
{
bool supported = true;
- std::array<DataType, 4> supportedTypes
+ std::array<DataType, 5> supportedTypes
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1877,12 +1925,13 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
IgnoreUnused(descriptor);
bool supported = true;
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,5> supportedTypes =
{
- DataType::Float32,
- DataType::Float16,
- DataType::QAsymmU8,
- DataType::QSymmS16
+ DataType::BFloat16,
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QAsymmU8,
+ DataType::QSymmS16
};
supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
@@ -1922,11 +1971,12 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
if (biases.has_value())
{
- std::array<DataType,3> biasesSupportedTypes =
+ std::array<DataType,4> biasesSupportedTypes =
{
- DataType::Float32,
- DataType::Float16,
- DataType::Signed32
+ DataType::BFloat16,
+ DataType::Float32,
+ DataType::Float16,
+ DataType::Signed32
};
supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
"Reference TransposeConvolution2d: biases is not a supported type.");
@@ -1944,8 +1994,9 @@ bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
bool supported = true;
// Define supported output and inputs types.
- std::array<DataType, 4> supportedTypes =
+ std::array<DataType, 5> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
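Every Is*Supported change above is the same two-step edit: add DataType::BFloat16 to the std::array of supported types and bump the array's size parameter so the initializer still fits (several hunks exist only to do that renumbering and reindent the list). The surrounding pattern, sketched with the rule helpers this file already uses — their definitions live elsewhere in the backend, so the shapes here are assumptions:

    bool supported = true;
    std::array<armnn::DataType, 5> supportedTypes =
    {
        armnn::DataType::BFloat16,
        armnn::DataType::Float32,
        armnn::DataType::Float16,
        armnn::DataType::QAsymmU8,
        armnn::DataType::QSymmS16
    };

    // TypeAnyOf(tensor, list) holds when the tensor's DataType is in the list;
    // CheckSupportRule appends the message to reasonIfUnsupported on failure.
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference layer: input is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference layer: output is not a supported type.");
    return supported;
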
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 52d71df936..1d82421490 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -50,6 +50,11 @@ bool IsSigned32(const WorkloadInfo& info)
return IsDataType<DataType::Signed32>(info);
}
+bool IsBFloat16(const WorkloadInfo& info)
+{
+ return IsDataType<DataType::BFloat16>(info);
+}
+
bool IsFloat16(const WorkloadInfo& info)
{
return IsDataType<DataType::Float16>(info);
@@ -441,6 +446,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescripto
{
return std::make_unique<RefPadFloat16Workload>(descriptor, info);
}
+ else if (IsBFloat16(info))
+ {
+ return std::make_unique<RefPadBFloat16Workload>(descriptor, info);
+ }
return MakeWorkload<RefPadFloat32Workload, RefPadQAsymm8Workload>(descriptor, info);
}
@@ -451,6 +460,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueD
{
return std::make_unique<RefPermuteQSymm16Workload>(descriptor, info);
}
+ else if (IsBFloat16(info))
+ {
+ return std::make_unique<RefPermuteBFloat16Workload>(descriptor, info);
+ }
return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteQAsymm8Workload,
NullWorkload, NullWorkload, NullWorkload>(descriptor, info);
}
@@ -568,6 +581,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateTranspose(const TransposeQu
{
return std::make_unique<RefTransposeQSymm16Workload>(descriptor, info);
}
+ else if (IsBFloat16(info))
+ {
+ return std::make_unique<RefTransposeBFloat16Workload>(descriptor, info);
+ }
return MakeWorkloadHelper<RefTransposeFloat16Workload, RefTransposeFloat32Workload, RefTransposeQAsymm8Workload,
NullWorkload, NullWorkload, NullWorkload>(descriptor, info);
}
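IsBFloat16 reuses the same IsDataType<ArmnnType> template as the existing IsFloat16/IsSigned32 helpers. A hedged sketch of what such a predicate typically checks — whether any tensor in the WorkloadInfo carries the type; the real template is defined earlier in this file and may differ in detail:

    #include <algorithm>
    #include <backendsCommon/WorkloadInfo.hpp>

    template <armnn::DataType ArmnnType>
    bool IsDataTypeSketch(const armnn::WorkloadInfo& info)
    {
        auto matches = [](const armnn::TensorInfo& t) { return t.GetDataType() == ArmnnType; };
        return std::any_of(info.m_InputTensorInfos.begin(),  info.m_InputTensorInfos.end(),  matches)
            || std::any_of(info.m_OutputTensorInfos.begin(), info.m_OutputTensorInfos.end(), matches);
    }
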
diff --git a/src/backends/reference/test/RefLayerSupportTests.cpp b/src/backends/reference/test/RefLayerSupportTests.cpp
index f0c49aceb0..ab749c1a5c 100644
--- a/src/backends/reference/test/RefLayerSupportTests.cpp
+++ b/src/backends/reference/test/RefLayerSupportTests.cpp
@@ -48,6 +48,12 @@ BOOST_AUTO_TEST_CASE(IsLayerSupportedReferenceAddition)
BOOST_CHECK(supportChecker.IsAdditionSupported(in0, in1, out, reasonNotSupported));
}
+BOOST_AUTO_TEST_CASE(IsLayerSupportedBFloat16Reference)
+{
+ armnn::RefWorkloadFactory factory;
+ IsLayerSupportedTests<armnn::RefWorkloadFactory, armnn::DataType::BFloat16>(&factory);
+}
+
BOOST_AUTO_TEST_CASE(IsLayerSupportedFloat16Reference)
{
armnn::RefWorkloadFactory factory;
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 40bf600331..a6bfe3575c 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -70,6 +70,14 @@ ARMNN_AUTO_TEST_CASE(SimpleConvolution2dAsymmetricPaddingNhwc,
ARMNN_AUTO_TEST_CASE(SimpleConvolution2dSquareNhwc, SimpleConvolution2d3x3NhwcTest, false)
+ARMNN_AUTO_TEST_CASE(Convolution2d3x3Dilation3x3BFloat16,
+ Convolution2d3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+ false,
+ DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(Convolution2d3x3Dilation3x3NhwcBFloat16,
+ Convolution2d3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+ false,
+ DataLayout::NHWC)
ARMNN_AUTO_TEST_CASE(Convolution2d3x3Dilation3x3,
Convolution2d3x3Dilation3x3Test<DataType::Float32, DataType::Float32>,
false,
@@ -95,6 +103,14 @@ ARMNN_AUTO_TEST_CASE(Convolution2d3x3Dilation3x3NhwcInt16,
false,
DataLayout::NHWC)
+ARMNN_AUTO_TEST_CASE(Convolution2d2x3x3Dilation3x3BFloat16,
+ Convolution2d2x3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+ false,
+ DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(Convolution2d2x3x3Dilation3x3NhwcBFloat16,
+ Convolution2d2x3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+ false,
+ DataLayout::NHWC)
ARMNN_AUTO_TEST_CASE(Convolution2d2x3x3Dilation3x3,
Convolution2d2x3x3Dilation3x3Test<DataType::Float32, DataType::Float32>,
false,
@@ -120,6 +136,14 @@ ARMNN_AUTO_TEST_CASE(Convolution2d2x3x3Dilation3x3NhwcInt16,
false,
DataLayout::NHWC)
+ARMNN_AUTO_TEST_CASE(Convolution2d2x2Dilation2x2Padding2x2Stride3x3BFloat16,
+ Convolution2d2x2Dilation2x2Padding2x2Stride3x3Test<DataType::BFloat16, DataType::BFloat16>,
+ false,
+ DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(Convolution2d2x2Dilation2x2Padding2x2Stride3x3NhwcBFloat16,
+ Convolution2d2x2Dilation2x2Padding2x2Stride3x3Test<DataType::BFloat16, DataType::BFloat16>,
+ false,
+ DataLayout::NHWC)
ARMNN_AUTO_TEST_CASE(Convolution2d2x2Dilation2x2Padding2x2Stride3x3,
Convolution2d2x2Dilation2x2Padding2x2Stride3x3Test<DataType::Float32, DataType::Float32>,
false,
@@ -179,6 +203,14 @@ ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d3x3Dilation3x3Nhwc,
DepthwiseConvolution2d3x3Dilation3x3Test<DataType::Float32, DataType::Float32>,
false,
DataLayout::NHWC)
+ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d3x3Dilation3x3BFloat16,
+ DepthwiseConvolution2d3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+ false,
+ DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d3x3Dilation3x3NhwcBFloat16,
+ DepthwiseConvolution2d3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+ false,
+ DataLayout::NHWC)
ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d3x3Dilation3x3Uint8,
DepthwiseConvolution2d3x3Dilation3x3Test<DataType::QAsymmU8, DataType::Signed32>,
false,
@@ -204,6 +236,14 @@ ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d2x3x3Dilation3x3Nhwc,
DepthwiseConvolution2d2x3x3Dilation3x3Test<DataType::Float32, DataType::Float32>,
false,
DataLayout::NHWC)
+ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d2x3x3Dilation3x3BFloat16,
+ DepthwiseConvolution2d2x3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+ false,
+ DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d2x3x3Dilation3x3NhwcBFloat16,
+ DepthwiseConvolution2d2x3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+ false,
+ DataLayout::NHWC)
ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d2x3x3Dilation3x3Uint8,
DepthwiseConvolution2d2x3x3Dilation3x3Test<DataType::QAsymmU8, DataType::Signed32>,
false,
@@ -228,6 +268,14 @@ ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2dMult2,
DepthwiseConvolution2dMult2Test<armnn::DataType::Float32, armnn::DataType::Float32>,
false,
armnn::DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2dMult4BFloat16,
+ DepthwiseConvolution2dMult4Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>,
+ false,
+ armnn::DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2dMult2BFloat16,
+ DepthwiseConvolution2dMult2Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>,
+ false,
+ armnn::DataLayout::NCHW)
ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2dDepthMul1,
DepthwiseConvolution2dDepthMul1Test, true, DataLayout::NCHW)
@@ -496,6 +544,7 @@ ARMNN_AUTO_TEST_CASE(CopyViaSplitterInt16, CopyViaSplitterInt16Test)
// Concat
ARMNN_AUTO_TEST_CASE(SimpleConcat, ConcatTest)
+ARMNN_AUTO_TEST_CASE(ConcatBFloat16, ConcatBFloat16Test)
ARMNN_AUTO_TEST_CASE(ConcatFloat16, ConcatFloat16Test)
ARMNN_AUTO_TEST_CASE(ConcatUint8, ConcatUint8Test)
ARMNN_AUTO_TEST_CASE(ConcatUint8DifferentQParams, ConcatUint8DifferentQParamsTest)
@@ -950,6 +999,11 @@ ARMNN_AUTO_TEST_CASE(LogSoftmaxFloat16_3, LogSoftmaxTest3<DataType::Float16>)
ARMNN_AUTO_TEST_CASE(LogSoftmaxFloat16_4, LogSoftmaxTest4<DataType::Float16>)
// Pad
+ARMNN_AUTO_TEST_CASE(PadBFloat162d, PadBFloat162dTest)
+ARMNN_AUTO_TEST_CASE(PadBFloat162dCustomPadding, PadBFloat162dCustomPaddingTest)
+ARMNN_AUTO_TEST_CASE(PadBFloat163d, PadBFloat163dTest)
+ARMNN_AUTO_TEST_CASE(PadBFloat164d, PadBFloat164dTest)
+
ARMNN_AUTO_TEST_CASE(PadFloat322d, PadFloat322dTest)
ARMNN_AUTO_TEST_CASE(PadFloat322dCustomPadding, PadFloat322dCustomPaddingTest)
ARMNN_AUTO_TEST_CASE(PadFloat323d, PadFloat323dTest)
@@ -1040,6 +1094,10 @@ ARMNN_AUTO_TEST_CASE(Rsqrt2dQuantisedSymm16, Rsqrt2dTest<DataType::QSymmS16>)
ARMNN_AUTO_TEST_CASE(Rsqrt3dQuantisedSymm16, Rsqrt3dTest<DataType::QSymmS16>)
// Permute
+ARMNN_AUTO_TEST_CASE(SimplePermuteBFloat16, SimplePermuteTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE(PermuteBFloat16ValueSet1Test, PermuteValueSet1Test<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE(PermuteBFloat16ValueSet2Test, PermuteValueSet2Test<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE(PermuteBFloat16ValueSet3Test, PermuteValueSet3Test<DataType::BFloat16>)
ARMNN_AUTO_TEST_CASE(SimplePermuteFloat32, SimplePermuteTest<DataType::Float32>)
ARMNN_AUTO_TEST_CASE(PermuteFloat32ValueSet1Test, PermuteValueSet1Test<DataType::Float32>)
ARMNN_AUTO_TEST_CASE(PermuteFloat32ValueSet2Test, PermuteValueSet2Test<DataType::Float32>)
@@ -1465,6 +1523,10 @@ ARMNN_AUTO_TEST_CASE(Slice2dInt16, Slice2dInt16Test)
ARMNN_AUTO_TEST_CASE(Slice1dInt16, Slice1dInt16Test)
// Transpose
+ARMNN_AUTO_TEST_CASE(SimpleTransposeBFloat16, SimpleTransposeTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE(TransposeBFloat16ValueSet1Test, TransposeValueSet1Test<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE(TransposeBFloat16ValueSet2Test, TransposeValueSet2Test<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE(TransposeBFloat16ValueSet3Test, TransposeValueSet3Test<DataType::BFloat16>)
ARMNN_AUTO_TEST_CASE(SimpleTransposeFloat32, SimpleTransposeTest<DataType::Float32>)
ARMNN_AUTO_TEST_CASE(TransposeFloat32ValueSet1Test, TransposeValueSet1Test<DataType::Float32>)
ARMNN_AUTO_TEST_CASE(TransposeFloat32ValueSet2Test, TransposeValueSet2Test<DataType::Float32>)
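None of the cases added above required a new test body: each ARMNN_AUTO_TEST_CASE re-instantiates an existing templated test at DataType::BFloat16. That works because the tests resolve the DataType enum to a C++ element type through a trait (sketched here after armnn's ResolveType; the sketch names are hypothetical), so once BFloat16 has a mapping, the same buffers and comparisons compile for the new type:

template <armnn::DataType DT>
struct ResolveTypeSketch;

template <>
struct ResolveTypeSketch<armnn::DataType::Float32>
{
    using Type = float;
};

template <>
struct ResolveTypeSketch<armnn::DataType::BFloat16>
{
    using Type = armnn::BFloat16; // 16-bit brain float: 8-bit exponent, 7-bit mantissa
};

// A test templated on the enum then declares its working data as e.g.
// std::vector<typename ResolveTypeSketch<DT>::Type> inputValues;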
diff --git a/src/backends/reference/workloads/Pad.cpp b/src/backends/reference/workloads/Pad.cpp
index 9fedb44f96..ffdd469609 100644
--- a/src/backends/reference/workloads/Pad.cpp
+++ b/src/backends/reference/workloads/Pad.cpp
@@ -152,6 +152,13 @@ void Pad(const TensorInfo& inputInfo,
}
}
+template void Pad<BFloat16>(const TensorInfo& inputInfo,
+ const TensorInfo& outputInfo,
+ std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
+ const BFloat16* inputData,
+ BFloat16* outData,
+ const float padValue);
+
template void Pad<float>(const TensorInfo& inputInfo,
const TensorInfo& outputInfo,
std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
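Pad<T> is declared in the header but defined in Pad.cpp, so every element type a workload can request needs an explicit instantiation in this translation unit; without the new template void Pad<BFloat16>(...) line the BFloat16 pad workload would compile but fail to link. The mechanics in miniature (names here are hypothetical):

// sketch of the header: declaration only
template <typename T>
void PadSketch(const T* input, T* output, float padValue);

// sketch of the .cpp: definition plus one instantiation per supported type
template <typename T>
void PadSketch(const T* input, T* output, float padValue)
{
    *output = T(padValue); // conversion from float must exist for T, as it does for armnn::BFloat16
    (void)input;           // the real kernel copies input into the interior (elided)
}

template void PadSketch<armnn::BFloat16>(const armnn::BFloat16*, armnn::BFloat16*, float);
template void PadSketch<float>(const float*, float*, float);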
diff --git a/src/backends/reference/workloads/RefPadWorkload.cpp b/src/backends/reference/workloads/RefPadWorkload.cpp
index 356f6b1172..777682d70c 100644
--- a/src/backends/reference/workloads/RefPadWorkload.cpp
+++ b/src/backends/reference/workloads/RefPadWorkload.cpp
@@ -33,6 +33,7 @@ void RefPadWorkload<DataType>::Execute() const
Pad(inputInfo, outputInfo, m_Data.m_Parameters.m_PadList, inputData, outputData, m_Data.m_Parameters.m_PadValue);
}
+template class RefPadWorkload<DataType::BFloat16>;
template class RefPadWorkload<DataType::Float32>;
template class RefPadWorkload<DataType::Float16>;
template class RefPadWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefPadWorkload.hpp b/src/backends/reference/workloads/RefPadWorkload.hpp
index 28fb55386e..5134ac8bff 100644
--- a/src/backends/reference/workloads/RefPadWorkload.hpp
+++ b/src/backends/reference/workloads/RefPadWorkload.hpp
@@ -30,6 +30,7 @@ public:
void Execute() const override;
};
+using RefPadBFloat16Workload = RefPadWorkload<DataType::BFloat16>;
using RefPadFloat32Workload = RefPadWorkload<DataType::Float32>;
using RefPadFloat16Workload = RefPadWorkload<DataType::Float16>;
using RefPadQAsymm8Workload = RefPadWorkload<DataType::QAsymmU8>;
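This header/source pair shows the whole recipe for adding an element type to a reference workload, and the Permute and Transpose files below repeat it verbatim: an explicit template class instantiation in the .cpp so the code for that DataType is emitted exactly once, plus a using alias in the header so the factory can name the concrete type. In outline (sketch names are hypothetical):

// RefSketchWorkload.hpp
template <armnn::DataType DT>
class RefSketchWorkload // : public BaseWorkload<SomeQueueDescriptor>
{
public:
    void Execute() const;
};

using RefSketchBFloat16Workload = RefSketchWorkload<armnn::DataType::BFloat16>;

// RefSketchWorkload.cpp
template <armnn::DataType DT>
void RefSketchWorkload<DT>::Execute() const
{
    // kernel call elided
}

template class RefSketchWorkload<armnn::DataType::BFloat16>;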
diff --git a/src/backends/reference/workloads/RefPermuteWorkload.cpp b/src/backends/reference/workloads/RefPermuteWorkload.cpp
index d0e1431ffd..5751ed80a3 100644
--- a/src/backends/reference/workloads/RefPermuteWorkload.cpp
+++ b/src/backends/reference/workloads/RefPermuteWorkload.cpp
@@ -28,6 +28,7 @@ void RefPermuteWorkload<DataType>::Execute() const
src->Map(), dst->Map(), sizeof(T));
}
+template class RefPermuteWorkload<DataType::BFloat16>;
template class RefPermuteWorkload<DataType::Float16>;
template class RefPermuteWorkload<DataType::Float32>;
template class RefPermuteWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefPermuteWorkload.hpp b/src/backends/reference/workloads/RefPermuteWorkload.hpp
index 00a33850aa..a8d308e47c 100644
--- a/src/backends/reference/workloads/RefPermuteWorkload.hpp
+++ b/src/backends/reference/workloads/RefPermuteWorkload.hpp
@@ -27,6 +27,7 @@ public:
void Execute() const override;
};
+using RefPermuteBFloat16Workload = RefPermuteWorkload<DataType::BFloat16>;
using RefPermuteFloat16Workload = RefPermuteWorkload<DataType::Float16>;
using RefPermuteFloat32Workload = RefPermuteWorkload<DataType::Float32>;
using RefPermuteQAsymm8Workload = RefPermuteWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefTransposeWorkload.cpp b/src/backends/reference/workloads/RefTransposeWorkload.cpp
index 6bdfb2111d..242668b6b1 100644
--- a/src/backends/reference/workloads/RefTransposeWorkload.cpp
+++ b/src/backends/reference/workloads/RefTransposeWorkload.cpp
@@ -27,6 +27,7 @@ void RefTransposeWorkload<DataType>::Execute() const
armnnUtils::Transpose(GetTensorInfo(src).GetShape(), mappings, src->Map(), dst->Map(), sizeof(T));
}
+template class RefTransposeWorkload<DataType::BFloat16>;
template class RefTransposeWorkload<DataType::Float16>;
template class RefTransposeWorkload<DataType::Float32>;
template class RefTransposeWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefTransposeWorkload.hpp b/src/backends/reference/workloads/RefTransposeWorkload.hpp
index 4b1c3d303b..dcfe618b75 100644
--- a/src/backends/reference/workloads/RefTransposeWorkload.hpp
+++ b/src/backends/reference/workloads/RefTransposeWorkload.hpp
@@ -27,6 +27,7 @@ public:
void Execute() const override;
};
+using RefTransposeBFloat16Workload = RefTransposeWorkload<DataType::BFloat16>;
using RefTransposeFloat16Workload = RefTransposeWorkload<DataType::Float16>;
using RefTransposeFloat32Workload = RefTransposeWorkload<DataType::Float32>;
using RefTransposeQAsymm8Workload = RefTransposeWorkload<DataType::QAsymmU8>;
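Notably, the Permute and Transpose instantiations add no BFloat16 arithmetic at all: Execute() hands sizeof(T) to the shared kernel, which shuffles elements as opaque bytes. A self-contained sketch of that element move, which is why one kernel serves Float32, Float16, BFloat16 and QAsymmU8 alike:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Copy element srcIndex of src into slot dstIndex of dst without ever
// interpreting the bytes; elementSize is sizeof(T) -- 2 for BFloat16,
// 4 for float, 1 for QAsymmU8.
void CopyElement(const void* src, void* dst,
                 std::size_t srcIndex, std::size_t dstIndex,
                 std::size_t elementSize)
{
    const auto* s = static_cast<const std::uint8_t*>(src) + srcIndex * elementSize;
    auto*       d = static_cast<std::uint8_t*>(dst) + dstIndex * elementSize;
    std::memcpy(d, s, elementSize);
}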