ArmNN
 22.05
ConvertConstants.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "Optimization.hpp"
9 
13 
14 #include <BFloat16.hpp>
15 #include <Half.hpp>
16 
17 namespace armnn
18 {
19 namespace optimizations
20 {
21 
23 {
24  static void Func(std::shared_ptr<ConstTensorHandle>& handle)
25  {
26  const TensorInfo& info = handle->GetTensorInfo();
27 
28  if (info.GetDataType() == DataType::BFloat16)
29  {
30  std::vector<float> newValues(info.GetNumElements());
31 
33  info.GetNumElements(),
34  newValues.data());
35 
36  TensorInfo newInfo(info.GetShape(), DataType::Float32, 0.0f, 0, true);
37  ConstTensor newInput(newInfo, newValues);
38  handle.reset(new ScopedTensorHandle(newInput));
39  }
40  }
41 };
42 
44 {
45  static void Func(std::shared_ptr<ConstTensorHandle>& handle)
46  {
47  const TensorInfo& info = handle->GetTensorInfo();
48 
49  if (info.GetDataType() == DataType::Float16)
50  {
51  std::vector<float> newValues(info.GetNumElements());
52 
54  info.GetNumElements(),
55  newValues.data());
56 
57  TensorInfo newInfo(info.GetShape(), DataType::Float32, 0.0f, 0, true);
58  ConstTensor newInput(newInfo, newValues);
59  handle.reset(new ScopedTensorHandle(newInput));
60  }
61  }
62 };
63 
65 {
66  static void Func(std::shared_ptr<ConstTensorHandle>& handle)
67  {
68  const TensorInfo& info = handle->GetTensorInfo();
69 
70  if (info.GetDataType() == DataType::Float32)
71  {
72  std::vector<BFloat16> newValues(info.GetNumElements());
73 
75  info.GetNumElements(),
76  newValues.data());
77 
78  TensorInfo newInfo(info.GetShape(), DataType::BFloat16, 0.0f, 0, true);
79  ConstTensor newInput(newInfo, newValues);
80  handle.reset(new ScopedTensorHandle(newInput));
81  }
82  }
83 };
84 
86 {
87  static void Func(std::shared_ptr<ConstTensorHandle>& handle)
88  {
89  const TensorInfo& info = handle->GetTensorInfo();
90 
91  if (info.GetDataType() == DataType::Float32)
92  {
93  std::vector<Half> newValues(info.GetNumElements());
94 
95  armnnUtils::FloatingPointConverter::ConvertFloat32To16(handle->GetConstTensor<float>(),
96  info.GetNumElements(),
97  newValues.data());
98 
99  TensorInfo newInfo(info.GetShape(), DataType::Float16, 0.0f, 0, true);
100  ConstTensor newInput(newInfo, newValues);
101  handle.reset(new ScopedTensorHandle(newInput));
102  }
103  }
104 };
105 
106 template<typename Converter, typename Predicate>
108 {
109 public:
110  ConvertConstants() = default;
111  ConvertConstants(const ConvertConstants&) = default;
112  virtual ~ConvertConstants() = default;
113 
114  void Run(Graph& graph, Layer& layer) const override
115  {
116  IgnoreUnused(graph);
117  if (Predicate::Test(layer))
118  {
119  layer.OperateOnConstantTensors(Converter::Func);
120  }
121  }
122 protected:
123 };
124 
126 {
127  static bool Test(const Layer& layer)
128  {
129  return layer.GetDataType() == DataType::Float32;
130  }
131 };
132 
134 {
135  static bool Test(const Layer& layer)
136  {
137  return layer.GetDataType() == DataType::Float16;
138  }
139 };
140 
142 {
143  static bool Test(const Layer& layer)
144  {
145  return layer.GetDataType() == DataType::BFloat16;
146  }
147 };
148 
151 
154 
155 } //namespace optimizations
156 } //namespace armnn
static bool Test(const Layer &layer)
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
static void Func(std::shared_ptr< ConstTensorHandle > &handle)
static bool Test(const Layer &layer)
static void ConvertBFloat16ToFloat32(const void *srcBFloat16Buffer, size_t numElements, float *dstFloat32Buffer)
void OperateOnConstantTensors(Op op)
Definition: Layer.hpp:304
static bool Test(const Layer &layer)
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
static void Func(std::shared_ptr< ConstTensorHandle > &handle)
static void Func(std::shared_ptr< ConstTensorHandle > &handle)
DataType GetDataType() const
Definition: Tensor.hpp:198
static void ConvertFloat32To16(const float *srcFloat32Buffer, size_t numElements, void *dstFloat16Buffer)
Converts a buffer of FP32 values to FP16, and stores in the given dstFloat16Buffer.
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:327
static void ConvertFloat16To32(const void *srcFloat16Buffer, size_t numElements, float *dstFloat32Buffer)
void Run(Graph &graph, Layer &layer) const override
static void Func(std::shared_ptr< ConstTensorHandle > &handle)
static void ConvertFloat32ToBFloat16(const float *srcFloat32Buffer, size_t numElements, void *dstBFloat16Buffer)
DataType GetDataType() const
Definition: Layer.cpp:313
half_float::half Half
Definition: Half.hpp:18
unsigned int GetNumElements() const
Definition: Tensor.hpp:196