aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaCommon/operatorMappings/ReluOperator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/tosaCommon/operatorMappings/ReluOperator.cpp')
-rw-r--r--src/backends/tosaCommon/operatorMappings/ReluOperator.cpp53
1 files changed, 44 insertions, 9 deletions
diff --git a/src/backends/tosaCommon/operatorMappings/ReluOperator.cpp b/src/backends/tosaCommon/operatorMappings/ReluOperator.cpp
index bd1a59670e..541b39cd8d 100644
--- a/src/backends/tosaCommon/operatorMappings/ReluOperator.cpp
+++ b/src/backends/tosaCommon/operatorMappings/ReluOperator.cpp
@@ -17,7 +17,7 @@
TosaSerializationBasicBlock* ConvertReluToTosaOperator(const Layer* layer,
const std::vector<const TensorInfo*>& inputs,
const std::vector<const TensorInfo*>& outputs,
- const ActivationDescriptor*)
+ const ActivationDescriptor* desc)
{
if (inputs.size() != 1)
{
@@ -31,7 +31,36 @@ TosaSerializationBasicBlock* ConvertReluToTosaOperator(const Layer* layer,
std::string inputName = std::string("input_");
std::string outputName = std::string("output0_");
- std::string blockName = std::string("Op_RELU_block_") + GetUniqueTosaMappingID();
+ std::string blockName = "";
+
+ int32_t clamp_min = 0;
+ int32_t clamp_max = 0;
+ float float_max = 0.0f;
+ switch (desc->m_Function)
+ {
+ case ActivationFunction::ReLu:
+ {
+ clamp_max = std::numeric_limits<int32_t>::max();
+ float_max = std::numeric_limits<float>::max();
+ blockName = std::string("Op_RELU_block_") + GetUniqueTosaMappingID();
+ break;
+ }
+ case ActivationFunction::BoundedReLu:
+ {
+ clamp_max = static_cast<int32_t>(desc->m_A);
+ float_max = desc->m_A;
+ blockName = std::string("Op_BOUNDED_RELU_block_") + GetUniqueTosaMappingID();
+ break;
+ }
+ case ActivationFunction::LeakyReLu:
+ {
+ throw Exception("LeakyRelu TOSA mappings are performed in ConvertLeakyReluToTosaOperator().");
+ }
+ default:
+ {
+ throw Exception("Activation function is not supported in ConvertReluToTosaOperator().");
+ }
+ }
// If a layer is present then the block will be used for execution, so input and output names need to be determined
// using the previous and following layers so the graph is connected correctly. For validation this doesn't matter.
@@ -60,8 +89,6 @@ TosaSerializationBasicBlock* ConvertReluToTosaOperator(const Layer* layer,
DType outputDType0 = ArmNNToDType(outputs[0]->GetDataType());
tensors.push_back(new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
- int32_t clamp_min = 0;
- int32_t clamp_max = std::numeric_limits<int32_t>::max();
std::string clampInputNameStr = inputName;
if (inputDType0 == tosa::DType::DType_INT8 || inputDType0 == tosa::DType::DType_INT16)
{
@@ -72,18 +99,26 @@ TosaSerializationBasicBlock* ConvertReluToTosaOperator(const Layer* layer,
int32_t input_zp = inputs[0]->GetQuantizationOffset();
int32_t output_zp = outputs[0]->GetQuantizationOffset();
- clamp_min = outputs[0]->GetQuantizationOffset();
+ clamp_min = output_zp;
+
+ if (desc->m_Function == ActivationFunction::BoundedReLu)
+ {
+ clamp_max = static_cast<int32_t>(std::round(desc->m_A / outputs[0]->GetQuantizationScale())) + output_zp;
+ }
+
if (inputDType0 == tosa::DType::DType_INT8)
{
clamp_min =
clamp_min < std::numeric_limits<int8_t>::min() ? std::numeric_limits<int8_t>::min() : clamp_min;
- clamp_max = std::numeric_limits<int8_t>::max();
+ clamp_max =
+ clamp_max > std::numeric_limits<int8_t>::max() ? std::numeric_limits<int8_t>::max() : clamp_max;
}
else
{
clamp_min =
clamp_min < std::numeric_limits<int16_t>::min() ? std::numeric_limits<int16_t>::min() : clamp_min;
- clamp_max = std::numeric_limits<int16_t>::max();
+ clamp_max =
+ clamp_max > std::numeric_limits<int16_t>::max() ? std::numeric_limits<int16_t>::max() : clamp_max;
}
TosaSerializationOperator* rescaleOp = nullptr;
@@ -101,8 +136,8 @@ TosaSerializationBasicBlock* ConvertReluToTosaOperator(const Layer* layer,
inputDType0,
{}));
}
-
- TosaClampAttribute attribute(clamp_min, clamp_max, 0, std::numeric_limits<float>::max());
+
+ TosaClampAttribute attribute(clamp_min, clamp_max, 0, float_max);
auto* clamp_op = new TosaSerializationOperator(Op_CLAMP,
Attribute_ClampAttribute,
&attribute,