aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2020-02-28 12:45:21 +0000
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2020-03-03 15:59:39 +0000
commitc3bf6efb48a4540c8addcc02813c9381e4fceb1f (patch)
tree55fd9c3fa639ee32aa847dda2a525a7badee9e8d
parent490b7becb8029ead26423b0d62e631a929e55d6c (diff)
downloadarmnn-c3bf6efb48a4540c8addcc02813c9381e4fceb1f.tar.gz
IVGCVSW-4508 Add BFloat16 data type
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: Ie2fcf06d0bae9e2ef958e60ab9e8b121fdc7b590
-rw-r--r--CMakeLists.txt1
-rw-r--r--include/armnn/Types.hpp1
-rw-r--r--include/armnn/TypesUtils.hpp2
-rw-r--r--src/armnn/ResolveType.hpp7
-rw-r--r--src/armnn/test/UtilsTests.cpp19
-rw-r--r--src/armnnUtils/BFloat16.hpp13
6 files changed, 43 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index f55f391622..6ab89a4ed3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -41,6 +41,7 @@ list(APPEND armnnUtils_sources
include/armnnUtils/FloatingPointConverter.hpp
include/armnnUtils/TensorUtils.hpp
include/armnnUtils/Transpose.hpp
+ src/armnnUtils/BFloat16.hpp
src/armnnUtils/Filesystem.hpp
src/armnnUtils/Filesystem.cpp
src/armnnUtils/Processes.hpp
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index d66a1fda90..29a0d4e364 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -40,6 +40,7 @@ enum class DataType
QuantizedSymm8PerAxis ARMNN_DEPRECATED_ENUM_MSG("Per Axis property inferred by number of scales in TensorInfo") = 6,
QSymmS8 = 7,
QAsymmS8 = 8,
+ BFloat16 = 9,
QuantisedAsymm8 ARMNN_DEPRECATED_ENUM_MSG("Use DataType::QAsymmU8 instead.") = QAsymmU8,
QuantisedSymm16 ARMNN_DEPRECATED_ENUM_MSG("Use DataType::QSymmS16 instead.") = QSymmS16
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index 53a933b480..5065152c7a 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -116,6 +116,7 @@ constexpr unsigned int GetDataTypeSize(DataType dataType)
{
switch (dataType)
{
+ case DataType::BFloat16:
case DataType::Float16: return 2U;
case DataType::Float32:
case DataType::Signed32: return 4U;
@@ -179,6 +180,7 @@ constexpr const char* GetDataTypeName(DataType dataType)
case DataType::QSymmS16: return "QSymm16";
case DataType::Signed32: return "Signed32";
case DataType::Boolean: return "Boolean";
+ case DataType::BFloat16: return "BFloat16";
default:
return "Unknown";
diff --git a/src/armnn/ResolveType.hpp b/src/armnn/ResolveType.hpp
index 66309344db..e1bea42d3c 100644
--- a/src/armnn/ResolveType.hpp
+++ b/src/armnn/ResolveType.hpp
@@ -6,6 +6,7 @@
#pragma once
#include "armnn/Types.hpp"
+#include "BFloat16.hpp"
#include "Half.hpp"
namespace armnn
@@ -62,6 +63,12 @@ struct ResolveTypeImpl<DataType::Boolean>
using Type = uint8_t;
};
+template<>
+struct ResolveTypeImpl<DataType::BFloat16>
+{
+ using Type = BFloat16;
+};
+
template<DataType DT>
using ResolveType = typename ResolveTypeImpl<DT>::Type;
diff --git a/src/armnn/test/UtilsTests.cpp b/src/armnn/test/UtilsTests.cpp
index 4c371d6ed9..faf4480029 100644
--- a/src/armnn/test/UtilsTests.cpp
+++ b/src/armnn/test/UtilsTests.cpp
@@ -85,6 +85,25 @@ BOOST_AUTO_TEST_CASE(HalfType)
BOOST_CHECK((GetDataTypeName(armnn::DataType::Float16) == std::string("Float16")));
}
+BOOST_AUTO_TEST_CASE(BFloatType)
+{
+ armnn::BFloat16 a = 16256;
+
+ // Test BFloat16 type
+ BOOST_CHECK_EQUAL(sizeof(a), 2);
+
+ // Test utility function returns correct type.
+ using ResolvedType = armnn::ResolveType<armnn::DataType::BFloat16>;
+ constexpr bool isBFloat16Type = std::is_same<armnn::BFloat16, ResolvedType>::value;
+ BOOST_CHECK(isBFloat16Type);
+
+ //Test utility functions return correct size
+ BOOST_CHECK(GetDataTypeSize(armnn::DataType::BFloat16) == 2);
+
+ //Test utility functions return correct name
+ BOOST_CHECK((GetDataTypeName(armnn::DataType::BFloat16) == std::string("BFloat16")));
+}
+
BOOST_AUTO_TEST_CASE(GraphTopologicalSortSimpleTest)
{
std::map<int, std::vector<int>> graph;
diff --git a/src/armnnUtils/BFloat16.hpp b/src/armnnUtils/BFloat16.hpp
new file mode 100644
index 0000000000..bce45aa1ff
--- /dev/null
+++ b/src/armnnUtils/BFloat16.hpp
@@ -0,0 +1,13 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <stdint.h>
+
+namespace armnn
+{
+ using BFloat16 = uint16_t;
+} //namespace armnn