ArmNN
 24.02
TosaRescaleOperatorUtils.hpp File Reference
Include dependency graph for TosaRescaleOperatorUtils.hpp:
This graph shows which files directly or indirectly include this file:

Go to the source code of this file.

Functions

void CreateRescaleTosaOperator (const std::string &inputName, const std::string &outputName, DType output_type, const std::vector< int32_t > &shape, int32_t scale_multiplier, int32_t scale_shift, int32_t input_zp, int32_t output_zp, bool double_round, bool scale32, TosaSerializationOperator **op, TosaSerializationTensor **tensor)
 
void CreateRescaleTosaOperator (const std::string &inputName, const std::string &outputName, DType output_type, const std::vector< int32_t > &shape, double scale, int32_t input_zp, int32_t output_zp, bool double_round, bool scale32, TosaSerializationOperator **op, TosaSerializationTensor **tensor)
 
void CreateFromInt32RescaleTosaOperator (const std::string &inputName, const std::string &outputName, DType output_type, const std::vector< int32_t > &shape, double output_scale, int32_t output_zp, TosaSerializationOperator **op, TosaSerializationTensor **tensor)
 

Function Documentation

◆ CreateFromInt32RescaleTosaOperator()

void CreateFromInt32RescaleTosaOperator ( const std::string &  inputName,
const std::string &  outputName,
DType  output_type,
const std::vector< int32_t > &  shape,
double  output_scale,
int32_t  output_zp,
TosaSerializationOperator **  op,
TosaSerializationTensor **  tensor 
)
inline

Definition at line 119 of file TosaRescaleOperatorUtils.hpp.

127 {
128  CreateRescaleTosaOperator(inputName, outputName, output_type, shape,
129  output_scale, 0, output_zp, true, true, op, tensor);
130 }

References CreateRescaleTosaOperator().

◆ CreateRescaleTosaOperator() [1/2]

void CreateRescaleTosaOperator ( const std::string &  inputName,
const std::string &  outputName,
DType  output_type,
const std::vector< int32_t > &  shape,
double  scale,
int32_t  input_zp,
int32_t  output_zp,
bool  double_round,
bool  scale32,
TosaSerializationOperator **  op,
TosaSerializationTensor **  tensor 
)
inline

Definition at line 55 of file TosaRescaleOperatorUtils.hpp.

66 {
67  // The code that follows is based on the behaviour specified in
68  // https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
69 
70  auto GetScaleParams = [](double scale, double& m, int32_t& n)
71  {
72  m = 0;
73  n = 0;
74 
75  double lastErr = 1e06;
76 
77  const int32_t numExponents = 62;
78  const double start = 1.0;
79  const double end = 2.0;
80 
81  // Slow iterative approach but running in Reference only
82  for (int32_t i = 0; i < numExponents; ++i)
83  {
84  double exp = 1.0 / (1 << i);
85  double currentM = scale / exp; // Find current m given value = currentM * exp
86  if ((currentM >= start) && (currentM < end))
87  {
88  double value = currentM * exp;
89  double err = std::abs(scale - value);
90  if (err < lastErr)
91  {
92  // Take the m, n that minimize the error
93  n = i;
94  m = currentM;
95  lastErr = err;
96  }
97  }
98  }
99  };
100 
101  auto GetMultiplierShiftByScale = [GetScaleParams](bool scale32, double scale, int32_t& multiplier, int32_t& shift)
102  {
103  double m = 0;
104  int32_t n = 0;
105 
106  GetScaleParams(scale, m, n);
107 
108  multiplier = (scale32) ? (1 << 30) * static_cast<int32_t>(m) : (1 << 14) * static_cast<int32_t>(m);
109  shift = (scale32) ? (30 + n) : (14 + n);
110  };
111 
112  int32_t multiplier;
113  int32_t shift;
114  GetMultiplierShiftByScale(scale32, scale, multiplier, shift);
115  CreateRescaleTosaOperator(inputName, outputName, output_type, shape, multiplier, shift,
116  input_zp, output_zp, double_round, scale32, op, tensor);
117 }

References CreateRescaleTosaOperator().

◆ CreateRescaleTosaOperator() [2/2]

void CreateRescaleTosaOperator ( const std::string &  inputName,
const std::string &  outputName,
DType  output_type,
const std::vector< int32_t > &  shape,
int32_t  scale_multiplier,
int32_t  scale_shift,
int32_t  input_zp,
int32_t  output_zp,
bool  double_round,
bool  scale32,
TosaSerializationOperator **  op,
TosaSerializationTensor **  tensor 
)
inline

Definition at line 10 of file TosaRescaleOperatorUtils.hpp.

22 {
23  if (!op)
24  {
25  throw armnn::Exception("CreateRescaleTosaOperator: nullptr op");
26  }
27 
28  std::vector<int32_t> multipliers{scale_multiplier};
29  std::vector<int32_t> shifts{scale_shift};
30  TosaRescaleAttribute attribute(input_zp,
31  output_zp,
32  multipliers,
33  shifts,
34  scale32,
35  double_round,
36  false);
37 
38  // op
39  *op = new TosaSerializationOperator(Op_RESCALE, Attribute_RescaleAttribute, &attribute, {inputName}, {outputName});
40  if (!(*op))
41  {
42  throw armnn::Exception("CreateRescaleTosaOperator: failed to created operator");
43  }
44  if (tensor != nullptr)
45  {
46  // tensor
47  *tensor = new TosaSerializationTensor(outputName, shape, output_type, {});
48  if (! (*tensor))
49  {
50  throw armnn::Exception("CreateRescaleTosaOperator: failed to created tensor");
51  }
52  }
53 }

Referenced by CreateFromInt32RescaleTosaOperator(), and CreateRescaleTosaOperator().

CreateRescaleTosaOperator
void CreateRescaleTosaOperator(const std::string &inputName, const std::string &outputName, DType output_type, const std::vector< int32_t > &shape, int32_t scale_multiplier, int32_t scale_shift, int32_t input_zp, int32_t output_zp, bool double_round, bool scale32, TosaSerializationOperator **op, TosaSerializationTensor **tensor)
Definition: TosaRescaleOperatorUtils.hpp:10
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46