diff options
-rw-r--r-- | CMakeLists.txt | 1 | ||||
-rw-r--r-- | include/armnn/Types.hpp | 1 | ||||
-rw-r--r-- | include/armnn/TypesUtils.hpp | 2 | ||||
-rw-r--r-- | src/armnn/ResolveType.hpp | 7 | ||||
-rw-r--r-- | src/armnn/test/UtilsTests.cpp | 19 | ||||
-rw-r--r-- | src/armnnUtils/BFloat16.hpp | 13 |
6 files changed, 43 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index f55f391622..6ab89a4ed3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,6 +41,7 @@ list(APPEND armnnUtils_sources include/armnnUtils/FloatingPointConverter.hpp include/armnnUtils/TensorUtils.hpp include/armnnUtils/Transpose.hpp + src/armnnUtils/BFloat16.hpp src/armnnUtils/Filesystem.hpp src/armnnUtils/Filesystem.cpp src/armnnUtils/Processes.hpp diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index d66a1fda90..29a0d4e364 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -40,6 +40,7 @@ enum class DataType QuantizedSymm8PerAxis ARMNN_DEPRECATED_ENUM_MSG("Per Axis property inferred by number of scales in TensorInfo") = 6, QSymmS8 = 7, QAsymmS8 = 8, + BFloat16 = 9, QuantisedAsymm8 ARMNN_DEPRECATED_ENUM_MSG("Use DataType::QAsymmU8 instead.") = QAsymmU8, QuantisedSymm16 ARMNN_DEPRECATED_ENUM_MSG("Use DataType::QSymmS16 instead.") = QSymmS16 diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp index 53a933b480..5065152c7a 100644 --- a/include/armnn/TypesUtils.hpp +++ b/include/armnn/TypesUtils.hpp @@ -116,6 +116,7 @@ constexpr unsigned int GetDataTypeSize(DataType dataType) { switch (dataType) { + case DataType::BFloat16: case DataType::Float16: return 2U; case DataType::Float32: case DataType::Signed32: return 4U; @@ -179,6 +180,7 @@ constexpr const char* GetDataTypeName(DataType dataType) case DataType::QSymmS16: return "QSymm16"; case DataType::Signed32: return "Signed32"; case DataType::Boolean: return "Boolean"; + case DataType::BFloat16: return "BFloat16"; default: return "Unknown"; diff --git a/src/armnn/ResolveType.hpp b/src/armnn/ResolveType.hpp index 66309344db..e1bea42d3c 100644 --- a/src/armnn/ResolveType.hpp +++ b/src/armnn/ResolveType.hpp @@ -6,6 +6,7 @@ #pragma once #include "armnn/Types.hpp" +#include "BFloat16.hpp" #include "Half.hpp" namespace armnn @@ -62,6 +63,12 @@ struct ResolveTypeImpl<DataType::Boolean> using Type = uint8_t; }; +template<> +struct ResolveTypeImpl<DataType::BFloat16> +{ + using Type = BFloat16; +}; + template<DataType DT> using ResolveType = typename ResolveTypeImpl<DT>::Type; diff --git a/src/armnn/test/UtilsTests.cpp b/src/armnn/test/UtilsTests.cpp index 4c371d6ed9..faf4480029 100644 --- a/src/armnn/test/UtilsTests.cpp +++ b/src/armnn/test/UtilsTests.cpp @@ -85,6 +85,25 @@ BOOST_AUTO_TEST_CASE(HalfType) BOOST_CHECK((GetDataTypeName(armnn::DataType::Float16) == std::string("Float16"))); } +BOOST_AUTO_TEST_CASE(BFloatType) +{ + armnn::BFloat16 a = 16256; + + // Test BFloat16 type + BOOST_CHECK_EQUAL(sizeof(a), 2); + + // Test utility function returns correct type. + using ResolvedType = armnn::ResolveType<armnn::DataType::BFloat16>; + constexpr bool isBFloat16Type = std::is_same<armnn::BFloat16, ResolvedType>::value; + BOOST_CHECK(isBFloat16Type); + + //Test utility functions return correct size + BOOST_CHECK(GetDataTypeSize(armnn::DataType::BFloat16) == 2); + + //Test utility functions return correct name + BOOST_CHECK((GetDataTypeName(armnn::DataType::BFloat16) == std::string("BFloat16"))); +} + BOOST_AUTO_TEST_CASE(GraphTopologicalSortSimpleTest) { std::map<int, std::vector<int>> graph; diff --git a/src/armnnUtils/BFloat16.hpp b/src/armnnUtils/BFloat16.hpp new file mode 100644 index 0000000000..bce45aa1ff --- /dev/null +++ b/src/armnnUtils/BFloat16.hpp @@ -0,0 +1,13 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <stdint.h> + +namespace armnn +{ + using BFloat16 = uint16_t; +} //namespace armnn |