diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 7c81a8f757..d4a0a6e865 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -638,6 +638,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt m_ParserFunctions[tflite::BuiltinOperator_NEG] = &TfLiteParserImpl::ParseNeg; m_ParserFunctions[tflite::BuiltinOperator_PACK] = &TfLiteParserImpl::ParsePack; m_ParserFunctions[tflite::BuiltinOperator_PAD] = &TfLiteParserImpl::ParsePad; + m_ParserFunctions[tflite::BuiltinOperator_PRELU] = &TfLiteParserImpl::ParsePrelu; m_ParserFunctions[tflite::BuiltinOperator_QUANTIZE] = &TfLiteParserImpl::ParseQuantize; m_ParserFunctions[tflite::BuiltinOperator_RELU] = &TfLiteParserImpl::ParseRelu; m_ParserFunctions[tflite::BuiltinOperator_RELU6] = &TfLiteParserImpl::ParseRelu6; @@ -1939,6 +1940,57 @@ void TfLiteParserImpl::ParsePad(size_t subgraphIndex, size_t operatorIndex) RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); } +void TfLiteParserImpl::ParsePrelu(size_t subgraphIndex, size_t operatorIndex) +{ + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(inputs.size(), 2); + + auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + auto layerName = fmt::format("Prelu:{}:{}", subgraphIndex, operatorIndex); + + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); + armnn::TensorInfo alphaTensorInfo = ToTensorInfo(inputs[1]); + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true); + CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0"); + + IConnectableLayer* layer = m_Network->AddPreluLayer(layerName.c_str()); + ARMNN_ASSERT(layer != nullptr); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + if (IsConstTensor(inputs[1])) + { + armnn::IInputSlot* slot = &(layer->GetInputSlot(0)); + RegisterConsumerOfTensor(subgraphIndex, 0, slot); + + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + + auto alphaTensorAndData = CreateConstTensorNonPermuted(inputs[1], alphaTensorInfo); + std::string constLayerName = fmt::format("Constant:{}", inputs[1]->name); + IConnectableLayer* constLayer = + m_Network->AddConstantLayer(alphaTensorAndData, constLayerName.c_str()); + ARMNN_ASSERT(constLayer != nullptr); + + constLayer->GetOutputSlot(0).SetTensorInfo(alphaTensorInfo); + constLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1)); + RegisterOutputSlots(subgraphIndex, + VIRTUAL_OPERATOR_ID, + constLayer, + { inputTensorIndexes[1] }); + } + else + { + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterInputSlots(subgraphIndex, operatorIndex, layer, inputTensorIndexes); + } + + auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes); +} + void TfLiteParserImpl::ParseQuantize(size_t subgraphIndex, size_t operatorIndex) { CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); |