aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/ConvertConstants.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/optimizations/ConvertConstants.hpp')
-rw-r--r--src/armnn/optimizations/ConvertConstants.hpp54
1 file changed, 54 insertions(+), 0 deletions(-)
diff --git a/src/armnn/optimizations/ConvertConstants.hpp b/src/armnn/optimizations/ConvertConstants.hpp
index 5e19c7bd05..f3ebcdf5d9 100644
--- a/src/armnn/optimizations/ConvertConstants.hpp
+++ b/src/armnn/optimizations/ConvertConstants.hpp
@@ -13,6 +13,7 @@
#include <armnn/utility/IgnoreUnused.hpp>
+#include <BFloat16.hpp>
#include <Half.hpp>
namespace armnn
@@ -20,6 +21,27 @@ namespace armnn
namespace optimizations
{
+// Conversion functor used by the ConvertConstants optimization:
+// widens a constant tensor held in a ScopedCpuTensorHandle from
+// BFloat16 to Float32, resetting the handle to a new Float32 tensor.
+// No-op when the tensor's data type is not BFloat16.
+struct BFloat16ToFloat32
+{
+ static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
+ {
+ const TensorInfo& info = handle->GetTensorInfo();
+
+ if (info.GetDataType() == DataType::BFloat16)
+ {
+ // Widen every BFloat16 element into a freshly allocated float buffer.
+ std::vector<float> newValues(info.GetNumElements());
+
+ armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(handle->GetTensor<BFloat16>(),
+ info.GetNumElements(),
+ newValues.data());
+
+ // Same shape, new element type; swap the handle's payload for the
+ // converted tensor (old storage is released by the unique_ptr reset).
+ TensorInfo newInfo(info.GetShape(), DataType::Float32);
+ ConstTensor newInput(newInfo, newValues);
+ handle.reset(new ScopedCpuTensorHandle(newInput));
+ }
+ }
+};
+
struct Float16ToFloat32
{
static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
@@ -41,6 +63,27 @@ struct Float16ToFloat32
}
};
+// Conversion functor used by the ConvertConstants optimization:
+// narrows a constant tensor held in a ScopedCpuTensorHandle from
+// Float32 to BFloat16, resetting the handle to a new BFloat16 tensor.
+// No-op when the tensor's data type is not Float32.
+// NOTE(review): narrowing to BFloat16 loses mantissa precision by design;
+// callers select this conversion deliberately via the optimization pass.
+struct Float32ToBFloat16
+{
+ static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
+ {
+ const TensorInfo& info = handle->GetTensorInfo();
+
+ if (info.GetDataType() == DataType::Float32)
+ {
+ // Narrow every float element into a freshly allocated BFloat16 buffer.
+ std::vector<BFloat16> newValues(info.GetNumElements());
+
+ armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(handle->GetTensor<float>(),
+ info.GetNumElements(),
+ newValues.data());
+
+ // Same shape, new element type; swap the handle's payload for the
+ // converted tensor (old storage is released by the unique_ptr reset).
+ TensorInfo newInfo(info.GetShape(), DataType::BFloat16);
+ ConstTensor newInput(newInfo, newValues);
+ handle.reset(new ScopedCpuTensorHandle(newInput));
+ }
+ }
+};
+
struct Float32ToFloat16
{
static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
@@ -97,6 +140,17 @@ struct IsFloat16Layer
}
};
+// Predicate functor for the ConvertConstants optimization: matches
+// layers whose data type is BFloat16 (mirrors IsFloat16Layer above).
+struct IsBFloat16Layer
+{
+ static bool Test(const Layer& layer)
+ {
+ return layer.GetDataType() == DataType::BFloat16;
+ }
+};
+
+// Public optimization aliases: each pairs a constant-conversion functor
+// with the layer predicate that selects where it applies.
+// BFloat16 constants -> Float32 where the consuming layer runs in Float32:
+using ConvertConstantsBFloatToFloat = ConvertConstants<BFloat16ToFloat32, IsFloat32Layer>;
+// Float32 constants -> BFloat16 where the consuming layer runs in BFloat16:
+using ConvertConstantsFloatToBFloat = ConvertConstants<Float32ToBFloat16, IsBFloat16Layer>;
+
// Pre-existing Half (Float16) counterparts of the aliases above.
using ConvertConstantsHalfToFloat = ConvertConstants<Float16ToFloat32, IsFloat32Layer>;
using ConvertConstantsFloatToHalf = ConvertConstants<Float32ToFloat16, IsFloat16Layer>;