ArmNN
 24.02
TosaRescaleOperatorUtils.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <armnn/Exceptions.hpp>
7 
8 #pragma once
9 
10 inline void CreateRescaleTosaOperator(const std::string& inputName,
11  const std::string& outputName,
12  DType output_type,
13  const std::vector<int32_t>& shape,
14  int32_t scale_multiplier,
15  int32_t scale_shift,
16  int32_t input_zp,
17  int32_t output_zp,
18  bool double_round,
19  bool scale32,
20  TosaSerializationOperator** op,
21  TosaSerializationTensor** tensor)
22 {
23  if (!op)
24  {
25  throw armnn::Exception("CreateRescaleTosaOperator: nullptr op");
26  }
27 
28  std::vector<int32_t> multipliers{scale_multiplier};
29  std::vector<int32_t> shifts{scale_shift};
30  TosaRescaleAttribute attribute(input_zp,
31  output_zp,
32  multipliers,
33  shifts,
34  scale32,
35  double_round,
36  false);
37 
38  // op
39  *op = new TosaSerializationOperator(Op_RESCALE, Attribute_RescaleAttribute, &attribute, {inputName}, {outputName});
40  if (!(*op))
41  {
42  throw armnn::Exception("CreateRescaleTosaOperator: failed to created operator");
43  }
44  if (tensor != nullptr)
45  {
46  // tensor
47  *tensor = new TosaSerializationTensor(outputName, shape, output_type, {});
48  if (! (*tensor))
49  {
50  throw armnn::Exception("CreateRescaleTosaOperator: failed to created tensor");
51  }
52  }
53 }
54 
55 inline void CreateRescaleTosaOperator(const std::string& inputName,
56  const std::string& outputName,
57  DType output_type,
58  const std::vector<int32_t>& shape,
59  double scale,
60  int32_t input_zp,
61  int32_t output_zp,
62  bool double_round,
63  bool scale32,
64  TosaSerializationOperator** op,
65  TosaSerializationTensor** tensor)
66 {
67  // The code that follows is based on the behaviour specified in
68  // https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
69 
70  auto GetScaleParams = [](double scale, double& m, int32_t& n)
71  {
72  m = 0;
73  n = 0;
74 
75  double lastErr = 1e06;
76 
77  const int32_t numExponents = 62;
78  const double start = 1.0;
79  const double end = 2.0;
80 
81  // Slow iterative approach but running in Reference only
82  for (int32_t i = 0; i < numExponents; ++i)
83  {
84  double exp = 1.0 / (1 << i);
85  double currentM = scale / exp; // Find current m given value = currentM * exp
86  if ((currentM >= start) && (currentM < end))
87  {
88  double value = currentM * exp;
89  double err = std::abs(scale - value);
90  if (err < lastErr)
91  {
92  // Take the m, n that minimize the error
93  n = i;
94  m = currentM;
95  lastErr = err;
96  }
97  }
98  }
99  };
100 
101  auto GetMultiplierShiftByScale = [GetScaleParams](bool scale32, double scale, int32_t& multiplier, int32_t& shift)
102  {
103  double m = 0;
104  int32_t n = 0;
105 
106  GetScaleParams(scale, m, n);
107 
108  multiplier = (scale32) ? (1 << 30) * static_cast<int32_t>(m) : (1 << 14) * static_cast<int32_t>(m);
109  shift = (scale32) ? (30 + n) : (14 + n);
110  };
111 
112  int32_t multiplier;
113  int32_t shift;
114  GetMultiplierShiftByScale(scale32, scale, multiplier, shift);
115  CreateRescaleTosaOperator(inputName, outputName, output_type, shape, multiplier, shift,
116  input_zp, output_zp, double_round, scale32, op, tensor);
117 }
118 
119 inline void CreateFromInt32RescaleTosaOperator(const std::string& inputName,
120  const std::string& outputName,
121  DType output_type,
122  const std::vector<int32_t>& shape,
123  double output_scale,
124  int32_t output_zp,
125  TosaSerializationOperator** op,
126  TosaSerializationTensor** tensor)
127 {
128  CreateRescaleTosaOperator(inputName, outputName, output_type, shape,
129  output_scale, 0, output_zp, true, true, op, tensor);
130 }
CreateRescaleTosaOperator
void CreateRescaleTosaOperator(const std::string &inputName, const std::string &outputName, DType output_type, const std::vector< int32_t > &shape, int32_t scale_multiplier, int32_t scale_shift, int32_t input_zp, int32_t output_zp, bool double_round, bool scale32, TosaSerializationOperator **op, TosaSerializationTensor **tensor)
Definition: TosaRescaleOperatorUtils.hpp:10
CreateFromInt32RescaleTosaOperator
void CreateFromInt32RescaleTosaOperator(const std::string &inputName, const std::string &outputName, DType output_type, const std::vector< int32_t > &shape, double output_scale, int32_t output_zp, TosaSerializationOperator **op, TosaSerializationTensor **tensor)
Definition: TosaRescaleOperatorUtils.hpp:119
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
Exceptions.hpp