ArmNN
 21.02
ConvertFp32NetworkToBf16.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "NetworkUtils.hpp"
8 #include "Optimization.hpp"
9 
11 
12 namespace armnn
13 {
14 namespace optimizations
15 {
16 
17 template <typename LayerT>
18 inline LayerT* ConvertWeight(Layer* l)
19 {
20  LayerT* layer = PolymorphicDowncast<LayerT*>(l);
21  if ((layer->GetType() == LayerType::Convolution2d || layer->GetType() == LayerType::FullyConnected)
22  && layer->m_Weight)
23  {
24  const TensorInfo& info = layer->m_Weight->GetTensorInfo();
25 
26  if (info.GetDataType() == DataType::Float32)
27  {
28  std::vector<BFloat16> newValues(info.GetNumElements());
29 
30  armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(layer->m_Weight->template GetTensor<float>(),
31  info.GetNumElements(),
32  newValues.data());
33 
34  TensorInfo newInfo(info);
35  newInfo.SetDataType(DataType::BFloat16);
36  ConstTensor newInput(newInfo, newValues);
37  layer->m_Weight.reset(new ScopedCpuTensorHandle(newInput));
38  }
39  }
40  return layer;
41 }
42 
44 {
45 public:
46 
47  void Run(Graph& graph, Layer& layer) const
48  {
49  // Only convert Float32 To BFloat16 for the Input of Convolution2d layer and FullyConnected layer.
50  // And also convert weight data type from Float32 to Bfloat16.
51  // Do not convert bias data type.
52  if (layer.GetType() == LayerType::Convolution2d)
53  {
54  if (layer.GetDataType() == DataType::Float32)
55  {
57  ConvertWeight<Convolution2dLayer>(&layer);
58  }
59  }
60  else if (layer.GetType() == LayerType::FullyConnected)
61  {
62  if (layer.GetDataType() == DataType::Float32)
63  {
65  ConvertWeight<FullyConnectedLayer>(&layer);
66  }
67  }
68  }
69 
70 protected:
71  ConvertFp32NetworkToBf16Impl() = default;
73 };
74 
76 
77 } // namespace optimizations
78 } // namespace armnn
Copyright (c) 2021 ARM Limited and Contributors.
LayerT * ConvertWeight(Layer *l)
DataType GetDataType() const
Definition: Tensor.hpp:194
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:314
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition: Layer.hpp:265
static void ConvertFloat32ToBFloat16(const float *srcFloat32Buffer, size_t numElements, void *dstBFloat16Buffer)
std::vector< ConvertFp32ToBf16Layer * > InsertConvertFp32ToBf16LayersBefore(Graph &graph, Layer &layer, bool expectCorrectInputType)
DataType GetDataType() const
Definition: Layer.cpp:283
unsigned int GetNumElements() const
Definition: Tensor.hpp:192