diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-05-24 18:50:24 +0100 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2021-05-25 13:09:40 +0000 |
commit | bfaee6b574301a54eab07a6021c39ae710977f7f (patch) | |
tree | 2283bebc47633d744880af8530bf1603ff131f0f /src/armnnTfLiteParser/TfLiteParser.cpp | |
parent | 37c430efaa85f84905cf96ace21f310339374053 (diff) | |
download | armnn-bfaee6b574301a54eab07a6021c39ae710977f7f.tar.gz |
IVGCVSW-3649 Add TfLite parser support for Prelu layer
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I3dedcc86efe1a67c709d9da636953e2fc400107b
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 52 |
1 file changed, 52 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 7c81a8f757..d4a0a6e865 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -638,6 +638,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt m_ParserFunctions[tflite::BuiltinOperator_NEG] = &TfLiteParserImpl::ParseNeg; m_ParserFunctions[tflite::BuiltinOperator_PACK] = &TfLiteParserImpl::ParsePack; m_ParserFunctions[tflite::BuiltinOperator_PAD] = &TfLiteParserImpl::ParsePad; + m_ParserFunctions[tflite::BuiltinOperator_PRELU] = &TfLiteParserImpl::ParsePrelu; m_ParserFunctions[tflite::BuiltinOperator_QUANTIZE] = &TfLiteParserImpl::ParseQuantize; m_ParserFunctions[tflite::BuiltinOperator_RELU] = &TfLiteParserImpl::ParseRelu; m_ParserFunctions[tflite::BuiltinOperator_RELU6] = &TfLiteParserImpl::ParseRelu6; @@ -1939,6 +1940,57 @@ void TfLiteParserImpl::ParsePad(size_t subgraphIndex, size_t operatorIndex) RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); } +void TfLiteParserImpl::ParsePrelu(size_t subgraphIndex, size_t operatorIndex) +{ + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(inputs.size(), 2); + + auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + auto layerName = fmt::format("Prelu:{}:{}", subgraphIndex, operatorIndex); + + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); + armnn::TensorInfo alphaTensorInfo = ToTensorInfo(inputs[1]); + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true); + CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0"); + + IConnectableLayer* layer = m_Network->AddPreluLayer(layerName.c_str()); + ARMNN_ASSERT(layer != nullptr); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + if 
(IsConstTensor(inputs[1])) + { + armnn::IInputSlot* slot = &(layer->GetInputSlot(0)); + RegisterConsumerOfTensor(subgraphIndex, 0, slot); + + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + + auto alphaTensorAndData = CreateConstTensorNonPermuted(inputs[1], alphaTensorInfo); + std::string constLayerName = fmt::format("Constant:{}", inputs[1]->name); + IConnectableLayer* constLayer = + m_Network->AddConstantLayer(alphaTensorAndData, constLayerName.c_str()); + ARMNN_ASSERT(constLayer != nullptr); + + constLayer->GetOutputSlot(0).SetTensorInfo(alphaTensorInfo); + constLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1)); + RegisterOutputSlots(subgraphIndex, + VIRTUAL_OPERATOR_ID, + constLayer, + { inputTensorIndexes[1] }); + } + else + { + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterInputSlots(subgraphIndex, operatorIndex, layer, inputTensorIndexes); + } + + auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes); +} + void TfLiteParserImpl::ParseQuantize(size_t subgraphIndex, size_t operatorIndex) { CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); |