diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 479fc4f474..5dbb5ee503 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -710,6 +710,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt = &TfLiteParserImpl::ParseLocalResponseNormalization; m_ParserFunctions[tflite::BuiltinOperator_LOGICAL_NOT] = &TfLiteParserImpl::ParseLogicalNot; m_ParserFunctions[tflite::BuiltinOperator_LOGISTIC] = &TfLiteParserImpl::ParseLogistic; + m_ParserFunctions[tflite::BuiltinOperator_LOG_SOFTMAX] = &TfLiteParserImpl::ParseLogSoftmax; m_ParserFunctions[tflite::BuiltinOperator_L2_NORMALIZATION] = &TfLiteParserImpl::ParseL2Normalization; m_ParserFunctions[tflite::BuiltinOperator_MAX_POOL_2D] = &TfLiteParserImpl::ParseMaxPool2D; m_ParserFunctions[tflite::BuiltinOperator_MAXIMUM] = &TfLiteParserImpl::ParseMaximum; @@ -1874,6 +1875,33 @@ void TfLiteParserImpl::ParseSoftmax(size_t subgraphIndex, size_t operatorIndex) RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); } +void TfLiteParserImpl::ParseLogSoftmax(size_t subgraphIndex, size_t operatorIndex) +{ + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + + LogSoftmaxDescriptor desc; + + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(inputs.size(), 1); + auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + auto layerName = fmt::format("LogSoftmax:{}:{}", subgraphIndex, operatorIndex); + IConnectableLayer* const layer = m_Network->AddLogSoftmaxLayer(desc, layerName.c_str()); + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + // register the input connection slots for the layer, connections are made after all layers have been created + // only the tensors for the inputs are relevant, exclude the const tensors + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); + + // register the output connection slots for the layer, connections are made after all layers have been created + auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); +} + void TfLiteParserImpl::ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex) { CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); |