diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 4acd30805e..31aab029ab 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -419,6 +419,7 @@ TfLiteParser::TfLiteParser() { // register supported operators m_ParserFunctions[tflite::BuiltinOperator_AVERAGE_POOL_2D] = &TfLiteParser::ParseAveragePool2D; + m_ParserFunctions[tflite::BuiltinOperator_BATCH_TO_SPACE_ND] = &TfLiteParser::ParseBatchToSpaceND; m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION] = &TfLiteParser::ParseConcatenation; m_ParserFunctions[tflite::BuiltinOperator_CONV_2D] = &TfLiteParser::ParseConv2D; m_ParserFunctions[tflite::BuiltinOperator_DEPTHWISE_CONV_2D] = &TfLiteParser::ParseDepthwiseConv2D; @@ -836,6 +837,54 @@ void TfLiteParser::ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex ParsePool(subgraphIndex, operatorIndex, PoolingAlgorithm::Average); } +void TfLiteParser::ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex) +{ + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(inputs.size(), 3); + + auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + armnn::TensorInfo blockShapeTensorInfo = ToTensorInfo(inputs[1]); + BufferRawPtr blockShapeBufferPtr = GetBuffer(m_Model, inputs[1]->buffer); + + armnn::TensorInfo cropsTensorInfo = ToTensorInfo(inputs[2]); + BufferRawPtr cropsBufferPtr = GetBuffer(m_Model, inputs[2]->buffer); + + std::vector<unsigned int> blockShape(blockShapeTensorInfo.GetNumElements()); + ::memcpy(blockShape.data(), blockShapeBufferPtr->data.data(), blockShapeTensorInfo.GetNumBytes()); + + std::vector<unsigned int> cropsVector(cropsTensorInfo.GetNumElements()); + ::memcpy(cropsVector.data(), cropsBufferPtr->data.data(), cropsTensorInfo.GetNumBytes()); + + size_t step = 2; + std::vector<std::pair<unsigned int, unsigned int>> crops; + for (unsigned int i = 0; i < cropsTensorInfo.GetNumElements() / step; ++i) + { + crops.emplace_back(cropsVector[i * step], cropsVector[i * step + 1]); + } + + armnn::BatchToSpaceNdDescriptor desc; + desc.m_BlockShape = blockShape; + desc.m_Crops = crops; + desc.m_DataLayout = armnn::DataLayout::NHWC; + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + + auto layerName = boost::str(boost::format("BatchToSpaceND:%1%:%2%") % subgraphIndex % operatorIndex); + IConnectableLayer* layer = m_Network->AddBatchToSpaceNdLayer(desc, layerName.c_str()); + + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); + + auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); +} + void TfLiteParser::ParseMaxPool2D(size_t subgraphIndex, size_t operatorIndex) { ParsePool(subgraphIndex, operatorIndex, PoolingAlgorithm::Max); |