diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2022-06-29 15:35:57 +0100 |
---|---|---|
committer | Nikhil Raj <nikhil.raj@arm.com> | 2022-07-08 15:20:34 +0100 |
commit | fd33a698ee3c588aa4064b70b7781ab25ff76f66 (patch) | |
tree | 96c5bd3fa17623c49015df13f92ea06e32dfc61f /src/armnnTfLiteParser/TfLiteParser.cpp | |
parent | 452c58080e9f8f577de87e0c07d0097aac97f3b8 (diff) | |
download | armnn-fd33a698ee3c588aa4064b70b7781ab25ff76f66.tar.gz |
IVGCVSW-7040 Add support for LOG_SOFTMAX to the TFLiteParser
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I1fedfdf2cd8871d6b307fce8620f40adadf75f04
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 479fc4f474..5dbb5ee503 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -710,6 +710,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt = &TfLiteParserImpl::ParseLocalResponseNormalization; m_ParserFunctions[tflite::BuiltinOperator_LOGICAL_NOT] = &TfLiteParserImpl::ParseLogicalNot; m_ParserFunctions[tflite::BuiltinOperator_LOGISTIC] = &TfLiteParserImpl::ParseLogistic; + m_ParserFunctions[tflite::BuiltinOperator_LOG_SOFTMAX] = &TfLiteParserImpl::ParseLogSoftmax; m_ParserFunctions[tflite::BuiltinOperator_L2_NORMALIZATION] = &TfLiteParserImpl::ParseL2Normalization; m_ParserFunctions[tflite::BuiltinOperator_MAX_POOL_2D] = &TfLiteParserImpl::ParseMaxPool2D; m_ParserFunctions[tflite::BuiltinOperator_MAXIMUM] = &TfLiteParserImpl::ParseMaximum; @@ -1874,6 +1875,33 @@ void TfLiteParserImpl::ParseSoftmax(size_t subgraphIndex, size_t operatorIndex) RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); } +void TfLiteParserImpl::ParseLogSoftmax(size_t subgraphIndex, size_t operatorIndex) +{ + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + + LogSoftmaxDescriptor desc; + + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(inputs.size(), 1); + auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + auto layerName = fmt::format("LogSoftmax:{}:{}", subgraphIndex, operatorIndex); + IConnectableLayer* const layer = m_Network->AddLogSoftmaxLayer(desc, layerName.c_str()); + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + // register the input connection slots for the layer, connections are made after all layers have been created + // only the tensors for the inputs are relevant, exclude the const tensors + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); + + // register the output connection slots for the layer, connections are made after all layers have been created + auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); +} + void TfLiteParserImpl::ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex) { CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); |