ArmNN
 21.02
ConvertConstants.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include "Optimization.hpp"

#include <armnnUtils/FloatingPointConverter.hpp>

#include <BFloat16.hpp>
#include <Half.hpp>

18 
19 namespace armnn
20 {
21 namespace optimizations
22 {
23 
25 {
26  static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
27  {
28  const TensorInfo& info = handle->GetTensorInfo();
29 
30  if (info.GetDataType() == DataType::BFloat16)
31  {
32  std::vector<float> newValues(info.GetNumElements());
33 
35  info.GetNumElements(),
36  newValues.data());
37 
38  TensorInfo newInfo(info.GetShape(), DataType::Float32);
39  ConstTensor newInput(newInfo, newValues);
40  handle.reset(new ScopedCpuTensorHandle(newInput));
41  }
42  }
43 };
44 
46 {
47  static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
48  {
49  const TensorInfo& info = handle->GetTensorInfo();
50 
51  if (info.GetDataType() == DataType::Float16)
52  {
53  std::vector<float> newValues(info.GetNumElements());
54 
56  info.GetNumElements(),
57  newValues.data());
58 
59  TensorInfo newInfo(info.GetShape(), DataType::Float32);
60  ConstTensor newInput(newInfo, newValues);
61  handle.reset(new ScopedCpuTensorHandle(newInput));
62  }
63  }
64 };
65 
67 {
68  static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
69  {
70  const TensorInfo& info = handle->GetTensorInfo();
71 
72  if (info.GetDataType() == DataType::Float32)
73  {
74  std::vector<BFloat16> newValues(info.GetNumElements());
75 
77  info.GetNumElements(),
78  newValues.data());
79 
80  TensorInfo newInfo(info.GetShape(), DataType::BFloat16);
81  ConstTensor newInput(newInfo, newValues);
82  handle.reset(new ScopedCpuTensorHandle(newInput));
83  }
84  }
85 };
86 
88 {
89  static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
90  {
91  const TensorInfo& info = handle->GetTensorInfo();
92 
93  if (info.GetDataType() == DataType::Float32)
94  {
95  std::vector<Half> newValues(info.GetNumElements());
96 
98  info.GetNumElements(),
99  newValues.data());
100 
101  TensorInfo newInfo(info.GetShape(), DataType::Float16);
102  ConstTensor newInput(newInfo, newValues);
103  handle.reset(new ScopedCpuTensorHandle(newInput));
104  }
105  }
106 };
107 
108 template<typename Converter, typename Predicate>
110 {
111 public:
// Default and copy construction are both compiler-generated.
112  ConvertConstants() = default;
113  ConvertConstants(const ConvertConstants&) = default;
// Virtual, compiler-generated destructor; Run() below is declared `override`, so this type is
// used through a polymorphic base.
114  virtual ~ConvertConstants() = default;
115 
116  void Run(Graph& graph, Layer& layer) const override
117  {
118  IgnoreUnused(graph);
119  if (Predicate::Test(layer))
120  {
121  layer.OperateOnConstantTensors(Converter::Func);
122  }
123  }
124 protected:
125 };
126 
128 {
129  static bool Test(const Layer& layer)
130  {
131  return layer.GetDataType() == DataType::Float32;
132  }
133 };
134 
136 {
137  static bool Test(const Layer& layer)
138  {
139  return layer.GetDataType() == DataType::Float16;
140  }
141 };
142 
144 {
145  static bool Test(const Layer& layer)
146  {
147  return layer.GetDataType() == DataType::BFloat16;
148  }
149 };
150 
153 
156 
157 } //namespace optimizations
158 } //namespace armnn
static bool Test(const Layer &layer)
const TensorShape & GetShape() const
Definition: Tensor.hpp:187
static bool Test(const Layer &layer)
static void ConvertBFloat16ToFloat32(const void *srcBFloat16Buffer, size_t numElements, float *dstFloat32Buffer)
void OperateOnConstantTensors(Op op)
Definition: Layer.hpp:298
static bool Test(const Layer &layer)
static void Func(std::unique_ptr< ScopedCpuTensorHandle > &handle)
static void Func(std::unique_ptr< ScopedCpuTensorHandle > &handle)
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
static void Func(std::unique_ptr< ScopedCpuTensorHandle > &handle)
static void Func(std::unique_ptr< ScopedCpuTensorHandle > &handle)
DataType GetDataType() const
Definition: Tensor.hpp:194
static void ConvertFloat32To16(const float *srcFloat32Buffer, size_t numElements, void *dstFloat16Buffer)
Converts a buffer of FP32 values to FP16 and stores the result in the given dstFloat16Buffer.
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:314
static void ConvertFloat16To32(const void *srcFloat16Buffer, size_t numElements, float *dstFloat32Buffer)
void Run(Graph &graph, Layer &layer) const override
static void ConvertFloat32ToBFloat16(const float *srcFloat32Buffer, size_t numElements, void *dstBFloat16Buffer)
DataType GetDataType() const
Definition: Layer.cpp:283
half_float::half Half
Definition: Half.hpp:16
unsigned int GetNumElements() const
Definition: Tensor.hpp:192