aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp52
1 files changed, 52 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 7c81a8f757..d4a0a6e865 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -638,6 +638,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt
m_ParserFunctions[tflite::BuiltinOperator_NEG] = &TfLiteParserImpl::ParseNeg;
m_ParserFunctions[tflite::BuiltinOperator_PACK] = &TfLiteParserImpl::ParsePack;
m_ParserFunctions[tflite::BuiltinOperator_PAD] = &TfLiteParserImpl::ParsePad;
+ m_ParserFunctions[tflite::BuiltinOperator_PRELU] = &TfLiteParserImpl::ParsePrelu;
m_ParserFunctions[tflite::BuiltinOperator_QUANTIZE] = &TfLiteParserImpl::ParseQuantize;
m_ParserFunctions[tflite::BuiltinOperator_RELU] = &TfLiteParserImpl::ParseRelu;
m_ParserFunctions[tflite::BuiltinOperator_RELU6] = &TfLiteParserImpl::ParseRelu6;
@@ -1939,6 +1940,57 @@ void TfLiteParserImpl::ParsePad(size_t subgraphIndex, size_t operatorIndex)
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
}
+void TfLiteParserImpl::ParsePrelu(size_t subgraphIndex, size_t operatorIndex)
+{
+ CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+ auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(inputs.size(), 2);
+
+ auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ auto layerName = fmt::format("Prelu:{}:{}", subgraphIndex, operatorIndex);
+
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+ armnn::TensorInfo alphaTensorInfo = ToTensorInfo(inputs[1]);
+ armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true);
+ CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
+
+ IConnectableLayer* layer = m_Network->AddPreluLayer(layerName.c_str());
+ ARMNN_ASSERT(layer != nullptr);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ if (IsConstTensor(inputs[1]))
+ {
+ armnn::IInputSlot* slot = &(layer->GetInputSlot(0));
+ RegisterConsumerOfTensor(subgraphIndex, 0, slot);
+
+ auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+
+ auto alphaTensorAndData = CreateConstTensorNonPermuted(inputs[1], alphaTensorInfo);
+ std::string constLayerName = fmt::format("Constant:{}", inputs[1]->name);
+ IConnectableLayer* constLayer =
+ m_Network->AddConstantLayer(alphaTensorAndData, constLayerName.c_str());
+ ARMNN_ASSERT(constLayer != nullptr);
+
+ constLayer->GetOutputSlot(0).SetTensorInfo(alphaTensorInfo);
+ constLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1));
+ RegisterOutputSlots(subgraphIndex,
+ VIRTUAL_OPERATOR_ID,
+ constLayer,
+ { inputTensorIndexes[1] });
+ }
+ else
+ {
+ auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, inputTensorIndexes);
+ }
+
+ auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
+}
+
void TfLiteParserImpl::ParseQuantize(size_t subgraphIndex, size_t operatorIndex)
{
CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);