ArmNN
 24.05
TosaRescaleOperatorUtils.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#pragma once

#include <armnn/Exceptions.hpp>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <string>
#include <vector>
9 
10 inline void CreateRescaleTosaOperator(const std::string& inputName,
11  const std::string& outputName,
12  const std::vector<int32_t>& multipliers,
13  const std::vector<int32_t>& shifts,
14  int32_t input_zp,
15  int32_t output_zp,
16  bool double_round,
17  bool scale32,
18  bool per_channel,
19  TosaSerializationOperator** op)
20 {
21  if (!op)
22  {
23  throw armnn::Exception("CreateRescaleTosaOperator: nullptr op");
24  }
25 
26  TosaRescaleAttribute attribute(input_zp,
27  output_zp,
28  multipliers,
29  shifts,
30  scale32,
31  double_round,
32  per_channel,
33  false, // input_unsigned
34  false); // output_unsigned
35 
36  // op
37  *op = new TosaSerializationOperator(Op_RESCALE, Attribute_RescaleAttribute, &attribute, {inputName}, {outputName});
38  if (!(*op))
39  {
40  throw armnn::Exception("CreateRescaleTosaOperator: failed to created operator");
41  }
42 }
43 
44 inline void CreateRescaleTosaOperator(const std::string& inputName,
45  const std::string& outputName,
46  int32_t scale_multiplier,
47  int32_t scale_shift,
48  int32_t input_zp,
49  int32_t output_zp,
50  bool double_round,
51  bool scale32,
52  bool per_channel,
53  TosaSerializationOperator** op)
54 {
55  const std::vector<int32_t> multipliers{scale_multiplier};
56  const std::vector<int32_t> shifts{scale_shift};
57  CreateRescaleTosaOperator(inputName, outputName, multipliers, shifts,
58  input_zp, output_zp, double_round, scale32, per_channel, op);
59 }
60 
61 /// The following is taken from mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp in the LLVM project
62 /// From a scale value, generates multiplier and shift values where
63 /// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that
64 /// multiplier = mantissa*2^shift for 32-bit scaling.
65 static void ComputeMultiplierAndShiftTosaScale32(double scale,
66  int32_t &multiplier,
67  int32_t &shift)
68 {
69  const double mantissa = std::frexp(scale, &shift);
70  auto shiftedM = std::round(mantissa * (int64_t(1) << 31));
71 
72  // Can't be greater than 1.0.
73  if (!(shiftedM <= (int64_t(1) << 31)))
74  {
75  throw armnn::Exception("Shifted mantissa exceeds 32 signed bits");
76  }
77 
78  if (shiftedM == (int64_t(1) << 31))
79  {
80  shiftedM /= 2;
81  shift++;
82  }
83 
84  // TOSA expects right shift to be positive, and embed (1 << 31) into right
85  // shift bits.
86  shift = (-shift) + 31;
87 
88  if (!(shiftedM <= std::numeric_limits<int32_t>::max()))
89  {
90  throw armnn::Exception("Shifted mantissa exceeds 32-bit signed output type");
91  }
92 
93  multiplier = static_cast<int32_t>(shiftedM);
94 
95  // Shifting tops out at 62 bits. Right shift to make 62 bits the max.
96  // The limit of 62 on shift allows the shift to be decomposed as
97  // two right shifts of 31.
98  if (shift > 62)
99  {
100  // Shifting the multiplier by more than 32-bits is unnecessary.
101  multiplier = multiplier >> std::min<int32_t>(31, shift - 62);
102  shift = 62;
103  }
104 }
105 
106 /// The following is taken from mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp in the LLVM project
107 /// From a scale value, generates multiplier and shift values where
108 /// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that
109 /// multiplier = mantissa*2^shift for 16-bit scaling.
110 static void ComputeMultiplierAndShiftTosaScale16(double scale,
111  int32_t &multiplier,
112  int32_t &shift)
113 {
114  const double mantissa = std::frexp(scale, &shift);
115  auto shiftedM = std::round(mantissa * (int64_t(1) << 15));
116 
117  // Can't be greater than 1.0.
118  if (!(shiftedM <= (int64_t(1) << 15)))
119  {
120  throw armnn::Exception("Shifted mantissa exceeds 16 signed bits");
121  }
122 
123  if (shiftedM == (int64_t(1) << 15))
124  {
125  shiftedM /= 2;
126  shift++;
127  }
128 
129  // TOSA expects right shift to be positive and embed (1 << 15) into right
130  // shift bits.
131  shift = (-shift) + 15;
132 
133  if (!(shiftedM <= std::numeric_limits<int32_t>::max()))
134  {
135  throw armnn::Exception("Shifted mantissa exceeds 32-bit signed output type");
136  }
137 
138  multiplier = static_cast<int32_t>(shiftedM);
139 
140  // Shifting tops out at 62 bits. Right shift to make 62 bits the max.
141  // The limit of 62 on shift allows the shift to be decomposed as
142  // two right shifts of 31.
143  if (shift > 62)
144  {
145  // Shifting the multiplier by more than 31-bits is unnecessary.
146  multiplier = multiplier >> std::min<int32_t>(31, shift - 62);
147  shift = 62;
148  }
149 }
150 
151 inline void CreateRescaleTosaOperator(const std::string& inputName,
152  const std::string& outputName,
153  double scale,
154  int32_t input_zp,
155  int32_t output_zp,
156  bool double_round,
157  bool scale32,
158  TosaSerializationOperator** op)
159 {
160  int32_t multiplier;
161  int32_t shift;
162 
163  if (scale32)
164  {
165  ComputeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
166  }
167  else
168  {
169  ComputeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
170  }
171 
172  CreateRescaleTosaOperator(inputName, outputName, multiplier, shift,
173  input_zp, output_zp, double_round, scale32, false, op);
174 }
175 
176 inline void CreateRescaleTosaOperatorPerChannel(const std::string& inputName,
177  const std::string& outputName,
178  int32_t input_zp,
179  int32_t output_zp,
180  bool double_round,
181  bool scale32,
182  double input_scale,
183  double output_scale,
184  const std::vector<float>& weight_scales,
185  TosaSerializationOperator** op)
186 {
187  std::vector<int32_t> op_tensor_multipliers;
188  std::vector<int32_t> op_tensor_shifts;
189  op_tensor_multipliers.reserve(weight_scales.size());
190  op_tensor_shifts.reserve(weight_scales.size());
191 
192  for (const float& weight_scale : weight_scales)
193  {
194  double op_tensor_scale = (input_scale * weight_scale) / output_scale;
195  int32_t multiplier;
196  int32_t shift;
197 
198  if (scale32)
199  {
200  ComputeMultiplierAndShiftTosaScale32(op_tensor_scale, multiplier, shift);
201  }
202  else
203  {
204  ComputeMultiplierAndShiftTosaScale16(op_tensor_scale, multiplier, shift);
205  }
206 
207  op_tensor_multipliers.push_back(multiplier);
208  op_tensor_shifts.push_back(shift);
209  }
210 
211  CreateRescaleTosaOperator(inputName, outputName, op_tensor_multipliers, op_tensor_shifts,
212  input_zp, output_zp, double_round, scale32, true, op);
213 }
214 
215 inline void CreateFromInt32RescaleTosaOperator(const std::string& inputName,
216  const std::string& outputName,
217  double output_scale,
218  int32_t output_zp,
219  TosaSerializationOperator** op)
220 {
221  CreateRescaleTosaOperator(inputName, outputName, output_scale,
222  0, output_zp, true, true, op);
223 }
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
CreateRescaleTosaOperatorPerChannel
void CreateRescaleTosaOperatorPerChannel(const std::string &inputName, const std::string &outputName, int32_t input_zp, int32_t output_zp, bool double_round, bool scale32, double input_scale, double output_scale, const std::vector< float > &weight_scales, TosaSerializationOperator **op)
Definition: TosaRescaleOperatorUtils.hpp:176
Exceptions.hpp
CreateFromInt32RescaleTosaOperator
void CreateFromInt32RescaleTosaOperator(const std::string &inputName, const std::string &outputName, double output_scale, int32_t output_zp, TosaSerializationOperator **op)
Definition: TosaRescaleOperatorUtils.hpp:215
CreateRescaleTosaOperator
void CreateRescaleTosaOperator(const std::string &inputName, const std::string &outputName, const std::vector< int32_t > &multipliers, const std::vector< int32_t > &shifts, int32_t input_zp, int32_t output_zp, bool double_round, bool scale32, bool per_channel, TosaSerializationOperator **op)
Definition: TosaRescaleOperatorUtils.hpp:10