18 const std::vector<const TensorInfo*>& inputs,
19 const std::vector<const TensorInfo*>& outputs,
22 if (inputs.size() != 1)
24 throw armnn::Exception(
"ConvertActivationToTosaOperator: 1 input tensors required.");
27 if (outputs.size() != 1)
29 throw armnn::Exception(
"ConvertActivationToTosaOperator: 1 output tensor required.");
32 std::string inputName = std::string(
"input_");
35 std::string outputName = std::string(
"output0_");
46 std::vector<TosaSerializationTensor*> tensors;
51 std::vector<int32_t> inputShape0;
52 DType inputDType0 = DType::DType_UNKNOWN;
53 if(inputName.find(
"input_") != std::string::npos)
57 tensors.push_back(
new TosaSerializationTensor(inputName, inputShape0, inputDType0, {}));
61 DType outputDType0 =
ArmNNToDType(outputs[0]->GetDataType());
62 tensors.push_back(
new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
64 #if TOSA_COMPAT_VERSION(0, 60, 0)
67 if (inputDType0 == DType::DType_FP32 ||
68 inputDType0 == DType::DType_FP16)
71 TosaSerializationOperator* alphaOp =
nullptr;
72 TosaSerializationTensor* alphaTensor =
nullptr;
73 CreateConstTosaOperator<float>(outputNameAlpha,
74 activationDescriptor->
m_A,
79 tensors.push_back(alphaTensor);
83 TosaMulAttribute mulAttribute(shift);
84 TosaSerializationOperator* mulOp =
new TosaSerializationOperator(Op_MUL,
85 Attribute_MulAttribute,
87 {inputName, outputNameAlpha},
89 tensors.push_back(
new TosaSerializationTensor(outputNameMul, inputShape0, inputDType0, {}));
91 TosaSerializationOperator* op =
nullptr;
92 if (activationDescriptor->
m_A <= 1.0)
94 op =
new TosaSerializationOperator(Op_MAXIMUM,
97 {inputName, outputNameMul},
102 op =
new TosaSerializationOperator(Op_MINIMUM,
105 {inputName, outputNameMul},
112 return new TosaSerializationBasicBlock(blockName,
114 {alphaOp, mulOp, op},
125 DType rescale_type = DType::DType_INT32;
126 float alpha = activationDescriptor->
m_A;
127 double scale_alpha = inputs[0]->GetQuantizationScale() * alpha / outputs[0]->GetQuantizationScale();
128 double scale_identity = inputs[0]->GetQuantizationScale() / outputs[0]->GetQuantizationScale();
129 int32_t input_zp = inputs[0]->GetQuantizationOffset();
130 int32_t output_zp = outputs[0]->GetQuantizationOffset();
135 TosaSerializationOperator* rescaleAlphaOp =
nullptr;
137 outputNameRescaleAlpha,
144 tensors.push_back(
new TosaSerializationTensor(outputNameRescaleAlpha,
150 TosaSerializationOperator* rescaleIdentityOp =
nullptr;
152 outputNameRescaleIdentity,
159 tensors.push_back(
new TosaSerializationTensor(outputNameRescaleIdentity,
174 TosaSerializationOperator* op =
nullptr;
177 op =
new TosaSerializationOperator(Op_MAXIMUM,
180 {outputNameRescaleAlpha, outputNameRescaleIdentity},
181 {outputNameRescaleMaxMin});
185 op =
new TosaSerializationOperator(Op_MINIMUM,
188 {outputNameRescaleAlpha, outputNameRescaleIdentity},
189 {outputNameRescaleMaxMin});
192 tensors.push_back(
new TosaSerializationTensor(outputNameRescaleMaxMin,
198 TosaSerializationOperator* rescaleOutputOp =
nullptr;
207 return new TosaSerializationBasicBlock(blockName,
209 {rescaleAlphaOp, rescaleIdentityOp, op, rescaleOutputOp},
219 TosaSerializationOperator* zeroOp =
nullptr;
220 TosaSerializationTensor* zeroTensor =
nullptr;
221 CreateConstTosaOperator<float>(outputNameZero,
227 tensors.push_back(zeroTensor);
230 TosaSerializationOperator* alphaOp =
nullptr;
231 TosaSerializationTensor* alphaTensor =
nullptr;
232 CreateConstTosaOperator<float>(outputNameAlpha,
233 activationDescriptor->
m_A,
238 tensors.push_back(alphaTensor);
242 TosaMulAttribute mulAttribute(shift);
243 TosaSerializationOperator* mulOp =
new TosaSerializationOperator(Op_MUL,
244 Attribute_MulAttribute,
246 {inputName, outputNameAlpha},
248 tensors.push_back(
new TosaSerializationTensor(outputNameMul, inputShape0, inputDType0, {}));
251 TosaSerializationOperator* geOp =
new TosaSerializationOperator(Op_GREATER_EQUAL,
254 {inputName, outputNameZero},
256 tensors.push_back(
new TosaSerializationTensor(outputNameGE, outputShape0, DType::DType_BOOL, {}));
259 TosaSerializationOperator* selOp =
new TosaSerializationOperator(Op_SELECT,
262 {outputNameGE, inputName, outputNameMul},
267 return new TosaSerializationBasicBlock(blockName,
269 {zeroOp, alphaOp, mulOp, geOp, selOp},