diff options
Diffstat (limited to 'src/backends')
8 files changed, 66 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/test/layerTests/DebugTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/DebugTestImpl.cpp index 149779b9ef..42fe4876ff 100644 --- a/src/backends/backendsCommon/test/layerTests/DebugTestImpl.cpp +++ b/src/backends/backendsCommon/test/layerTests/DebugTestImpl.cpp @@ -309,6 +309,34 @@ LayerTestResult<float, 1> Debug1dFloat32Test( return Debug1dTest<armnn::DataType::Float32>(workloadFactory, memoryManager); } +LayerTestResult<armnn::BFloat16, 4> Debug4dBFloat16Test( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) +{ + return Debug4dTest<armnn::DataType::BFloat16>(workloadFactory, memoryManager); +} + +LayerTestResult<armnn::BFloat16, 3> Debug3dBFloat16Test( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) +{ + return Debug3dTest<armnn::DataType::BFloat16>(workloadFactory, memoryManager); +} + +LayerTestResult<armnn::BFloat16, 2> Debug2dBFloat16Test( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) +{ + return Debug2dTest<armnn::DataType::BFloat16>(workloadFactory, memoryManager); +} + +LayerTestResult<armnn::BFloat16, 1> Debug1dBFloat16Test( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) +{ + return Debug1dTest<armnn::DataType::BFloat16>(workloadFactory, memoryManager); +} + LayerTestResult<uint8_t, 4> Debug4dUint8Test( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) diff --git a/src/backends/backendsCommon/test/layerTests/DebugTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/DebugTestImpl.hpp index 7582663a85..cf4b237d27 100644 --- a/src/backends/backendsCommon/test/layerTests/DebugTestImpl.hpp +++ b/src/backends/backendsCommon/test/layerTests/DebugTestImpl.hpp @@ -7,6 +7,8 @@ #include "LayerTestResult.hpp" +#include <BFloat16.hpp> + #include <armnn/backends/IBackendInternal.hpp> #include <backendsCommon/WorkloadFactory.hpp> @@ -26,6 +28,22 @@ LayerTestResult<float, 1> Debug1dFloat32Test( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); +LayerTestResult<armnn::BFloat16, 4> Debug4dBFloat16Test( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + +LayerTestResult<armnn::BFloat16, 3> Debug3dBFloat16Test( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + +LayerTestResult<armnn::BFloat16, 2> Debug2dBFloat16Test( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + +LayerTestResult<armnn::BFloat16, 1> Debug1dBFloat16Test( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + LayerTestResult<uint8_t, 4> Debug4dUint8Test( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 9dc576cac8..94128fe7cd 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -495,8 +495,9 @@ bool RefLayerSupport::IsDebugSupported(const TensorInfo& input, { bool supported = true; - std::array<DataType, 7> supportedTypes = + std::array<DataType, 8> supportedTypes = { + DataType::BFloat16, DataType::Float16, DataType::Float32, DataType::QAsymmU8, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 1d82421490..aebf19bf28 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -203,6 +203,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvolution2d(const Convolu std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor, const WorkloadInfo& info) const { + if (IsBFloat16(info)) + { + return std::make_unique<RefDebugBFloat16Workload>(descriptor, info); + } if (IsFloat16(info)) { return std::make_unique<RefDebugFloat16Workload>(descriptor, info); diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index a6bfe3575c..73b2a05e09 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -1417,6 +1417,11 @@ ARMNN_AUTO_TEST_CASE(Debug3dFloat32, Debug3dFloat32Test) ARMNN_AUTO_TEST_CASE(Debug2dFloat32, Debug2dFloat32Test) ARMNN_AUTO_TEST_CASE(Debug1dFloat32, Debug1dFloat32Test) +ARMNN_AUTO_TEST_CASE(Debug4dBFloat16, Debug4dBFloat16Test) +ARMNN_AUTO_TEST_CASE(Debug3dBFloat16, Debug3dBFloat16Test) +ARMNN_AUTO_TEST_CASE(Debug2dBFloat16, Debug2dBFloat16Test) +ARMNN_AUTO_TEST_CASE(Debug1dBFloat16, Debug1dBFloat16Test) + ARMNN_AUTO_TEST_CASE(Debug4dUint8, Debug4dUint8Test) ARMNN_AUTO_TEST_CASE(Debug3dUint8, Debug3dUint8Test) ARMNN_AUTO_TEST_CASE(Debug2dUint8, Debug2dUint8Test) diff --git a/src/backends/reference/workloads/Debug.cpp b/src/backends/reference/workloads/Debug.cpp index 49e9e02ffb..aadbc7613b 100644 --- a/src/backends/reference/workloads/Debug.cpp +++ b/src/backends/reference/workloads/Debug.cpp @@ -5,6 +5,7 @@ #include "Debug.hpp" +#include <BFloat16.hpp> #include <Half.hpp> #include <boost/numeric/conversion/cast.hpp> @@ -88,6 +89,12 @@ void Debug(const TensorInfo& inputInfo, std::cout << " }" << std::endl; } +template void Debug<BFloat16>(const TensorInfo& inputInfo, + const BFloat16* inputData, + LayerGuid guid, + const std::string& layerName, + unsigned int slotIndex); + template void Debug<Half>(const TensorInfo& inputInfo, const Half* inputData, LayerGuid guid, diff --git a/src/backends/reference/workloads/RefDebugWorkload.cpp b/src/backends/reference/workloads/RefDebugWorkload.cpp index af714a3ca7..72b03effca 100644 --- a/src/backends/reference/workloads/RefDebugWorkload.cpp +++ b/src/backends/reference/workloads/RefDebugWorkload.cpp @@ -44,6 +44,7 @@ void RefDebugWorkload<DataType>::RegisterDebugCallback(const DebugCallbackFuncti m_Callback = func; } +template class RefDebugWorkload<DataType::BFloat16>; template class RefDebugWorkload<DataType::Float16>; template class RefDebugWorkload<DataType::Float32>; template class RefDebugWorkload<DataType::QAsymmU8>; diff --git a/src/backends/reference/workloads/RefDebugWorkload.hpp b/src/backends/reference/workloads/RefDebugWorkload.hpp index 5a2a1cdf1b..1ccbcc590b 100644 --- a/src/backends/reference/workloads/RefDebugWorkload.hpp +++ b/src/backends/reference/workloads/RefDebugWorkload.hpp @@ -37,6 +37,7 @@ private: DebugCallbackFunction m_Callback; }; +using RefDebugBFloat16Workload = RefDebugWorkload<DataType::BFloat16>; using RefDebugFloat16Workload = RefDebugWorkload<DataType::Float16>; using RefDebugFloat32Workload = RefDebugWorkload<DataType::Float32>; using RefDebugQAsymmU8Workload = RefDebugWorkload<DataType::QAsymmU8>; |