diff options
author | Idriss Chaouch <idriss.chaouch@arm.com> | 2023-09-01 17:58:38 +0100 |
---|---|---|
committer | idriss.chaouch <idriss.chaouch@arm.com> | 2023-09-08 08:32:43 +0000 |
commit | 564c13dc098eb9353ac15e2609712ab8db9bf350 (patch) | |
tree | 6cb52e904e3cd001d650a6386b1105ee21b08847 /src/armnnTfLiteParser/TfLiteParser.cpp | |
parent | 04e3eb5d339c3778f26c69651bf1464c8ab5331d (diff) | |
download | armnn-564c13dc098eb9353ac15e2609712ab8db9bf350.tar.gz |
IVGCVSW-7525 Add broadcast_to to TFLite Parser
* Changing the optimizer
* Changing EndToEnd Tests
Signed-off-by: Idriss Chaouch <idriss.chaouch@arm.com>
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ib581794280322a39cfc5ea3c4e6a6398cf723d5e
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 61 |
1 file changed, 61 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 052aac6101..3f4f0d811f 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -750,6 +750,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt
     m_ParserFunctions[tflite::BuiltinOperator_AVERAGE_POOL_2D] = &TfLiteParserImpl::ParseAveragePool2D;
     m_ParserFunctions[tflite::BuiltinOperator_BATCH_TO_SPACE_ND] = &TfLiteParserImpl::ParseBatchToSpaceND;
     m_ParserFunctions[tflite::BuiltinOperator_BATCH_MATMUL] = &TfLiteParserImpl::ParseBatchMatMul;
+    m_ParserFunctions[tflite::BuiltinOperator_BROADCAST_TO] = &TfLiteParserImpl::ParseBroadcastTo;
     m_ParserFunctions[tflite::BuiltinOperator_CEIL] = &TfLiteParserImpl::ParseCeil;
     m_ParserFunctions[tflite::BuiltinOperator_CAST] = &TfLiteParserImpl::ParseCast;
     m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION] = &TfLiteParserImpl::ParseConcatenation;
@@ -1894,6 +1895,66 @@ void TfLiteParserImpl::ParseBatchToSpaceND(size_t subgraphIndex, size_t operator
     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
 }
+// Parses a TFLite BROADCAST_TO operator into an Arm NN BroadcastTo layer.
+// Expects two inputs (inputs[0]: data tensor, inputs[1]: target-shape tensor)
+// and one output. The target shape is read from the shape tensor's constant
+// buffer; if the buffer carries no data, the output tensor's shape is used as
+// a fallback. Throws a parse exception when no shape information is available.
+void TfLiteParserImpl::ParseBroadcastTo(size_t subgraphIndex, size_t operatorIndex)
+{
+    CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+    auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+    CHECK_VALID_SIZE(inputs.size(), 2);
+
+    auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+    CHECK_VALID_SIZE(outputs.size(), 1);
+
+    TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+    TensorInfo shapeTensorInfo = ToTensorInfo(inputs[1]);
+    TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+
+    auto layerName = fmt::format("Broadcast_to:{}:{}", subgraphIndex, operatorIndex);
+
+    BroadcastToDescriptor descriptor;
+
+    // The shape input must be a constant tensor: fetch its flatbuffer buffer.
+    auto shapeBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+    if (shapeBufferPtr != nullptr)
+    {
+        std::vector<unsigned int> targetShape;
+        unsigned int numElement = shapeTensorInfo.GetNumElements();
+        // The buffer is reinterpreted as int32 values, i.e. this assumes the
+        // shape tensor's element type is int32 — TODO confirm against schema.
+        auto shapeData = reinterpret_cast<const int32_t*>(shapeBufferPtr->data.data());
+        if (shapeData)
+        {
+            for (unsigned int i = 0; i < numElement; ++i)
+            {
+                targetShape.push_back(armnn::numeric_cast<unsigned int>(shapeData[i]));
+            }
+            descriptor.m_BroadcastToShape = TensorShape(numElement, targetShape.data());
+        }
+        /// get dataShape from outputShape if missing
+        else
+        {
+            // Fallback only works when the output shape is meaningful
+            // (more than one element); otherwise there is nothing to go on.
+            if(outputTensorInfo.GetShape().GetNumElements() <= 1)
+            {
+                ARMNN_THROW_PARSE_EXCEPTION("For Broadcast_to layer, "
+                                            "data and output shape are not found in the buffer.");
+            }
+            descriptor.m_BroadcastToShape = outputTensorInfo.GetShape();
+        }
+    }
+    else
+    {
+        ARMNN_THROW_PARSE_EXCEPTION("For Broadcast_to layer, Shape data was not found in the buffer.");
+    }
+
+    IConnectableLayer* layer = m_Network->AddBroadcastToLayer(descriptor, layerName.c_str());
+    ARMNN_ASSERT(layer != nullptr);
+
+    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+    // Only the data tensor (input 0) is wired into the graph; the shape input
+    // was consumed above as constant metadata and is not registered as a slot.
+    auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+    RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+
+    auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+    RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
+}
+
 void TfLiteParserImpl::ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex)
 {
     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); |