diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 61 |
1 file changed, 61 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 052aac6101..3f4f0d811f 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -750,6 +750,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt m_ParserFunctions[tflite::BuiltinOperator_AVERAGE_POOL_2D] = &TfLiteParserImpl::ParseAveragePool2D; m_ParserFunctions[tflite::BuiltinOperator_BATCH_TO_SPACE_ND] = &TfLiteParserImpl::ParseBatchToSpaceND; m_ParserFunctions[tflite::BuiltinOperator_BATCH_MATMUL] = &TfLiteParserImpl::ParseBatchMatMul; + m_ParserFunctions[tflite::BuiltinOperator_BROADCAST_TO] = &TfLiteParserImpl::ParseBroadcastTo; m_ParserFunctions[tflite::BuiltinOperator_CEIL] = &TfLiteParserImpl::ParseCeil; m_ParserFunctions[tflite::BuiltinOperator_CAST] = &TfLiteParserImpl::ParseCast; m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION] = &TfLiteParserImpl::ParseConcatenation; @@ -1894,6 +1895,66 @@ void TfLiteParserImpl::ParseBatchToSpaceND(size_t subgraphIndex, size_t operator RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); } +void TfLiteParserImpl::ParseBroadcastTo(size_t subgraphIndex, size_t operatorIndex) +{ + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(inputs.size(), 2); + + auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); + TensorInfo shapeTensorInfo = ToTensorInfo(inputs[1]); + TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + + auto layerName = fmt::format("Broadcast_to:{}:{}", subgraphIndex, operatorIndex); + + BroadcastToDescriptor descriptor; + + auto shapeBufferPtr = GetBuffer(m_Model, inputs[1]->buffer); + if (shapeBufferPtr != nullptr) + { + std::vector<unsigned int> targetShape; + unsigned int numElement = 
shapeTensorInfo.GetNumElements(); + auto shapeData = reinterpret_cast<const int32_t*>(shapeBufferPtr->data.data()); + if (shapeData) + { + for (unsigned int i = 0; i < numElement; ++i) + { + targetShape.push_back(armnn::numeric_cast<unsigned int>(shapeData[i])); + } + descriptor.m_BroadcastToShape = TensorShape(numElement, targetShape.data()); + } + /// get dataShape from outputShape if missing + else + { + if(outputTensorInfo.GetShape().GetNumElements() <= 1) + { + ARMNN_THROW_PARSE_EXCEPTION("For Broadcast_to layer, " + "data and output shape are not found in the buffer."); + } + descriptor.m_BroadcastToShape = outputTensorInfo.GetShape(); + } + } + else + { + ARMNN_THROW_PARSE_EXCEPTION("For Broadcast_to layer, Shape data was not found in the buffer."); + } + + IConnectableLayer* layer = m_Network->AddBroadcastToLayer(descriptor, layerName.c_str()); + ARMNN_ASSERT(layer != nullptr); + + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); + + auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); +} + void TfLiteParserImpl::ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex) { CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); |