ArmNN
 22.05.01
ConvertFp32NetworkToBf16.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include "NetworkUtils.hpp"
#include "Optimization.hpp"

#include <armnn/utility/PolymorphicDowncast.hpp>
#include <armnnUtils/FloatingPointConverter.hpp>

#include <vector>

12 namespace armnn
13 {
14 namespace optimizations
15 {
16 
17 template <typename LayerT>
18 inline LayerT* ConvertWeight(Layer* l)
19 {
20  LayerT* layer = PolymorphicDowncast<LayerT*>(l);
21  if ((layer->GetType() == LayerType::Convolution2d || layer->GetType() == LayerType::FullyConnected)
22  && layer->m_Weight)
23  {
24  const TensorInfo& info = layer->m_Weight->GetTensorInfo();
25 
26  if (info.GetDataType() == DataType::Float32)
27  {
28  std::vector<BFloat16> newValues(info.GetNumElements());
29 
31  layer->m_Weight->template GetConstTensor<float>(),
32  info.GetNumElements(),
33  newValues.data());
34 
35  TensorInfo newInfo(info);
36  newInfo.SetDataType(DataType::BFloat16);
37  ConstTensor newInput(newInfo, newValues);
38  layer->m_Weight.reset(new ScopedTensorHandle(newInput));
39  }
40  }
41  return layer;
42 }
43 
45 {
46 public:
47 
48  void Run(Graph& graph, Layer& layer) const
49  {
50  // Only convert Float32 To BFloat16 for the Input of Convolution2d layer and FullyConnected layer.
51  // And also convert weight data type from Float32 to Bfloat16.
52  // Do not convert bias data type.
53  if (layer.GetType() == LayerType::Convolution2d)
54  {
55  if (layer.GetDataType() == DataType::Float32)
56  {
58  ConvertWeight<Convolution2dLayer>(&layer);
59  }
60  }
61  else if (layer.GetType() == LayerType::FullyConnected)
62  {
63  if (layer.GetDataType() == DataType::Float32)
64  {
66  ConvertWeight<FullyConnectedLayer>(&layer);
67  }
68  }
69  }
70 
71 protected:
72  ConvertFp32NetworkToBf16Impl() = default;
74 };
75 
77 
78 } // namespace optimizations
79 } // namespace armnn
Copyright (c) 2021 ARM Limited and Contributors.
LayerT * ConvertWeight(Layer *l)
DataType GetDataType() const
Definition: Tensor.hpp:198
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:327
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition: Layer.hpp:271
static void ConvertFloat32ToBFloat16(const float *srcFloat32Buffer, size_t numElements, void *dstBFloat16Buffer)
std::vector< ConvertFp32ToBf16Layer * > InsertConvertFp32ToBf16LayersBefore(Graph &graph, Layer &layer, bool expectCorrectInputType)
DataType GetDataType() const
Definition: Layer.cpp:313
unsigned int GetNumElements() const
Definition: Tensor.hpp:196