diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 47 |
1 files changed, 46 insertions, 1 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 049d6049a7..3fd81ff973 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -810,6 +810,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt m_ParserFunctions[tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR] = &TfLiteParserImpl::ParseResizeNearestNeighbor; m_ParserFunctions[tflite::BuiltinOperator_REVERSE_V2] = &TfLiteParserImpl::ParseReverseV2; m_ParserFunctions[tflite::BuiltinOperator_RSQRT] = &TfLiteParserImpl::ParseRsqrt; + m_ParserFunctions[tflite::BuiltinOperator_SCATTER_ND] = &TfLiteParserImpl::ParseScatterNd; m_ParserFunctions[tflite::BuiltinOperator_SQRT] = &TfLiteParserImpl::ParseSqrt; m_ParserFunctions[tflite::BuiltinOperator_SHAPE] = &TfLiteParserImpl::ParseShape; m_ParserFunctions[tflite::BuiltinOperator_SIN] = &TfLiteParserImpl::ParseSin; @@ -2279,6 +2280,50 @@ void TfLiteParserImpl::ParseLogSoftmax(size_t subgraphIndex, size_t operatorInde RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); } +void TfLiteParserImpl::ParseScatterNd(size_t subgraphIndex, size_t operatorIndex) +{ + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + + TfLiteParserImpl::TensorRawPtrVector inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(inputs.size(), 3); + TfLiteParserImpl::TensorRawPtrVector outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + armnn::TensorInfo indicesTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0); + armnn::TensorInfo updatesTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1); + armnn::TensorInfo shapeTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 2); + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + + // TFLite currently only have these options: update and no input given, just shape. + armnn::ScatterNdDescriptor descriptor(armnn::ScatterNdFunction::Update, false); + + const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex]; + const auto* options = operatorPtr->builtin_options.AsScatterNdOptions(); + IgnoreUnused(options); + + auto layerName = fmt::format("ScatterND:{}:{}", subgraphIndex, operatorIndex); + + IConnectableLayer* layer = m_Network->AddScatterNdLayer(descriptor, layerName.c_str()); + + if (!layer) + { + throw NullPointerException(fmt::format("Layer {} pointer is null {}", + operatorIndex, CHECK_LOCATION().AsString())); + } + + outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1, 2}); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterInputSlots(subgraphIndex, + operatorIndex, + layer, + {inputTensorIndexes[2], inputTensorIndexes[0], inputTensorIndexes[1]}); + + auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); +} + void TfLiteParserImpl::ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex) { CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); |