ArmNN
 20.05
ConvertFp32NetworkToBf16.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include "NetworkUtils.hpp"
#include "Optimization.hpp"

#include <memory>
#include <vector>
9 
11 
12 namespace armnn
13 {
14 namespace optimizations
15 {
16 
17 template <typename LayerT>
18 inline LayerT* ConvertWeight(Layer* l)
19 {
20  LayerT* layer = PolymorphicDowncast<LayerT*>(l);
21  if ((layer->GetType() == LayerType::Convolution2d || layer->GetType() == LayerType::FullyConnected)
22  && layer->m_Weight)
23  {
24  const TensorInfo& info = layer->m_Weight->GetTensorInfo();
25 
26  if (info.GetDataType() == DataType::Float32)
27  {
28  std::vector<BFloat16> newValues(info.GetNumElements());
29 
30  armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(layer->m_Weight->template GetTensor<float>(),
31  info.GetNumElements(),
32  newValues.data());
33 
34  TensorInfo newInfo(info.GetShape(), DataType::BFloat16);
35  ConstTensor newInput(newInfo, newValues);
36  layer->m_Weight.reset(new ScopedCpuTensorHandle(newInput));
37  }
38  }
39  return layer;
40 }
41 
43 {
44 public:
45 
46  void Run(Graph& graph, Layer& layer) const
47  {
48  // Only convert Float32 To BFloat16 for the Input of Convolution2d layer and FullyConnected layer.
49  // And also convert weight data type from Float32 to Bfloat16.
50  // Do not convert bias data type.
51  if (layer.GetType() == LayerType::Convolution2d)
52  {
53  if (layer.GetDataType() == DataType::Float32)
54  {
56  ConvertWeight<Convolution2dLayer>(&layer);
57  }
58  }
59  else if (layer.GetType() == LayerType::FullyConnected)
60  {
61  if (layer.GetDataType() == DataType::Float32)
62  {
64  ConvertWeight<FullyConnectedLayer>(&layer);
65  }
66  }
67  }
68 
69 protected:
70  ConvertFp32NetworkToBf16Impl() = default;
72 };
73 
75 
76 } // namespace optimizations
77 } // namespace armnn
const TensorShape & GetShape() const
Definition: Tensor.hpp:88
Copyright (c) 2020 ARM Limited.
LayerT * ConvertWeight(Layer *l)
DataType GetDataType() const
Definition: Tensor.hpp:95
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:199
static void ConvertFloat32ToBFloat16(const float *srcFloat32Buffer, size_t numElements, void *dstBFloat16Buffer)
std::vector< ConvertFp32ToBf16Layer * > InsertConvertFp32ToBf16LayersBefore(Graph &graph, Layer &layer, bool expectCorrectInputType)
DataType GetDataType() const
Definition: Layer.cpp:274
LayerType GetType() const
Definition: Layer.hpp:259
unsigned int GetNumElements() const
Definition: Tensor.hpp:93