ArmNN 21.02
NetworkQuantizationScheme.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Types.hpp>
9 
10 #include <cmath>
11 #include <algorithm>
12 
13 namespace armnn
14 {
15 
16 using OffsetScalePair = std::pair<float, int>;
17 
19 {
20  virtual OffsetScalePair ComputeScheme(double min, double max) const = 0;
21 
22  virtual int NumBits() const = 0;
23 
24  virtual DataType GetDataType() const = 0;
25 
26  virtual ~IQuantizationScheme() {}
27 };
28 
30 {
31  OffsetScalePair ComputeScheme(double min, double max) const override
32  {
33  if (min > max)
34  {
35  throw InvalidArgumentException("min > max will result in invalid quantization.");
36  }
37 
38  double highest = (1 << NumBits()) - 1;
39 
40  min = std::min(0.0, min); // min <= 0.0
41  max = std::max(0.0, max); // max >= 0.0
42 
43  // To avoid dividing by zero when quantizing a zero filled tensor
44  if (min == 0.0 && max == 0.0)
45  {
46  max = 1.0;
47  }
48 
49  // Assumes quantization range [0-highest]
50  double scale = (max-min) / highest;
51  double offset = -min / scale;
52 
53  // Clamp offset [0-highest]
54  offset = std::max(0.0, std::min(highest, offset));
55 
56  return std::make_pair(static_cast<float>(scale), static_cast<int>(std::round(offset)));
57  }
58 
59  int NumBits() const override { return 8; }
60 
61  DataType GetDataType() const override { return DataType::QAsymmU8; }
62 };
63 
65 {
66  OffsetScalePair ComputeScheme(double min, double max) const override
67  {
68  if (min > max)
69  {
70  throw InvalidArgumentException("min > max will result in invalid quantization.");
71  }
72 
73  double highest = (1 << NumBits()) - 1;
74 
75  min = std::min(0.0, min); // min <= 0.0
76  max = std::max(0.0, max); // max >= 0.0
77 
78  // To avoid dividing by zero when quantizing a zero filled tensor
79  if (min == 0.0 && max == 0.0)
80  {
81  max = 1.0;
82  }
83 
84  // Assumes quantization range [0-255]
85  double scale = (max-min) / highest ;
86  double offset = - min / scale;
87 
88  //Clamp 0 to Highest
89  offset = std::max(0.0, std::min(highest, offset));
90 
91  //-128 on offset to cast to signed range
92  return std::make_pair(static_cast<float>(scale), static_cast<int>(std::round(offset)-128));
93  }
94 
95  int NumBits() const override { return 8; }
96 
97  DataType GetDataType() const override { return DataType::QAsymmS8; }
98 };
99 
101 {
102  OffsetScalePair ComputeScheme(double min, double max) const override
103  {
104  if (min > max)
105  {
106  throw InvalidArgumentException("min > max will result in invalid quantization.");
107  }
108 
109  // To avoid dividing by zero when quantizing a zero filled tensor
110  if (min == 0.0 && max == 0.0)
111  {
112  max = 1.0;
113  }
114 
115  double highest = (1 << (NumBits()-1)) - 1; // (numbits-1) accounts for the sign bit
116 
117  double extent = std::max(std::abs(min), std::abs(max));
118  double scale = extent / highest;
119 
120  return std::make_pair(static_cast<float>(scale), 0);
121  }
122 
123  int NumBits() const override { return 8; }
124 
125  DataType GetDataType() const override { return DataType::QSymmS8; }
126 };
127 
129 {
130  OffsetScalePair ComputeScheme(double min, double max) const override
131  {
132  if (min > max)
133  {
134  throw InvalidArgumentException("min > max will result in invalid quantization.");
135  }
136 
137  // To avoid dividing by zero when quantizing a zero filled tensor
138  if (min == 0.0 && max == 0.0)
139  {
140  max = 1.0;
141  }
142 
143  double highest = (1 << (NumBits()-1)) - 1; // (numbits-1) accounts for the sign bit
144 
145  double extent = std::max(std::abs(min), std::abs(max));
146  double scale = extent / highest;
147 
148  return std::make_pair(static_cast<float>(scale), 0);
149 
150  }
151 
152  int NumBits() const override { return 16; }
153 
154  DataType GetDataType() const override { return DataType::QSymmS16; }
155 };
156 
157 } // namespace armnn
Doxygen symbol tooltips (hover text collected from the listing above):
- OffsetScalePair: std::pair<float, int>
- Copyright (c) 2021 ARM Limited and Contributors.
- IQuantizationScheme::ComputeScheme: virtual OffsetScalePair ComputeScheme(double min, double max) const = 0
- IQuantizationScheme::NumBits: virtual int NumBits() const = 0
- DataType (definition: Types.hpp:32)
- Override: OffsetScalePair ComputeScheme(double min, double max) const override
- IQuantizationScheme::GetDataType: virtual DataType GetDataType() const = 0
- Override: OffsetScalePair ComputeScheme(double min, double max) const override
- Override: OffsetScalePair ComputeScheme(double min, double max) const override
- Override: OffsetScalePair ComputeScheme(double min, double max) const override