aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp89
1 files changed, 89 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index b9a3522736..c00c2188a9 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -464,6 +464,7 @@ TfLiteParser::TfLiteParser()
m_ParserFunctions[tflite::BuiltinOperator_MUL] = &TfLiteParser::ParseMul;
m_ParserFunctions[tflite::BuiltinOperator_MEAN] = &TfLiteParser::ParseMean;
m_ParserFunctions[tflite::BuiltinOperator_PAD] = &TfLiteParser::ParsePad;
+ m_ParserFunctions[tflite::BuiltinOperator_SPLIT] = &TfLiteParser::ParseSplit;
}
void TfLiteParser::ResetParser()
@@ -1851,6 +1852,94 @@ void TfLiteParser::ParseDetectionPostProcess(size_t subgraphIndex, size_t operat
outputTensorIndexes[3]});
}
+void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex)
+{
+ CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+ const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
+ const auto * options = operatorPtr->builtin_options.AsSplitOptions();
+
+ const unsigned int numSplits = CHECKED_NON_NEGATIVE(options->num_splits);
+
+ auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(inputs.size(), 2);
+ auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(outputs.size(), numSplits);
+
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+ armnn::TensorInfo axisTensorInfo = ToTensorInfo(inputs[1]);
+
+ // This splitDim indicates the data format: 3 is the NHWC, 1 is the NCHW.
+ const unsigned int splitDim = static_cast<unsigned int>(axisTensorInfo.GetShape()[0]);
+
+ // Armnn supports split along the channel dimension for data formats NHWC and NCHW.
+ if (splitDim == 0 || splitDim == 2)
+ {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "Dimension %1% for split is not supported by Armnn. %2%")
+ % splitDim
+ % CHECK_LOCATION().AsString()));
+ }
+
+ auto inputDimSize = inputTensorInfo.GetNumDimensions();
+ if (inputDimSize != MaxNumOfTensorDimensions)
+ {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "The number of dimensions: %1% for input tensors of the "
+ "split op should be %2% %3%")
+ % inputTensorInfo.GetNumDimensions()
+ % MaxNumOfTensorDimensions
+ % CHECK_LOCATION().AsString()));
+ }
+
+ std::vector<unsigned int> splitterDimSizes(inputDimSize);
+
+ // Add current input shape to splitterDimSizes
+ for (unsigned int i = 0; i < inputDimSize; ++i)
+ {
+ splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
+ }
+
+ if (splitterDimSizes[splitDim] % numSplits != 0)
+ {
+ throw ParseException("Number of splits must evenly divide the dimension");
+ }
+ splitterDimSizes[splitDim] /= numSplits;
+
+ SplitterDescriptor splitDesc(numSplits);
+ for (unsigned int j = 0; j < numSplits; ++j)
+ {
+ // Set the size of the views.
+ for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
+ {
+ splitDesc.SetViewSize(j, dimIdx, splitterDimSizes[dimIdx]);
+ }
+ splitDesc.SetViewOriginCoord(j, splitDim, splitterDimSizes[splitDim] * j);
+ }
+
+ auto layerName = boost::str(boost::format("Split:%1%:%2%") % subgraphIndex % operatorIndex);
+ IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str());
+
+ auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+
+ TensorShape outShape = TensorShape(static_cast<unsigned int>(splitterDimSizes.size()),
+ splitterDimSizes.data());
+
+ for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
+ {
+ layer->GetOutputSlot(k).SetTensorInfo(armnn::TensorInfo(outShape,
+ inputTensorInfo.GetDataType()));
+ }
+
+ auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
+}
+
armnn::IConnectableLayer* TfLiteParser::AddFusedActivationLayer(armnn::IConnectableLayer* prevLayer,
unsigned int outputSlot,
tflite::ActivationFunctionType activationType)