Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--  src/armnnTfLiteParser/TfLiteParser.cpp  20
1 file changed, 12 insertions(+), 8 deletions(-)
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 1ee4950558..b7258b3ffc 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1971,11 +1971,15 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex)
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), numSplits);
- armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
- armnn::TensorInfo axisTensorInfo = ToTensorInfo(inputs[1]);
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[1]);
+ armnn::TensorInfo axisTensorInfo = ToTensorInfo(inputs[0]);
+
+ BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[0]->buffer);
+ std::vector<unsigned int> axisData(axisTensorInfo.GetNumElements());
+ ::memcpy(axisData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes());
- // This splitDim indicates the data format: 3 is the NHWC, 1 is the NCHW.
- const unsigned int splitDim = static_cast<unsigned int>(axisTensorInfo.GetShape()[0]);
+ BOOST_ASSERT(axisTensorInfo.GetNumElements() == 1);
+ const unsigned int splitDim = axisData[0];
// Armnn supports split along the channel dimension for data formats NHWC and NCHW.
if (splitDim == 0 || splitDim == 2)
@@ -1989,13 +1993,13 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex)
}
auto inputDimSize = inputTensorInfo.GetNumDimensions();
- if (inputDimSize != MaxNumOfTensorDimensions)
+ if (inputDimSize > MaxNumOfTensorDimensions)
{
throw ParseException(
boost::str(
boost::format(
"The number of dimensions: %1% for input tensors of the "
- "split op should be %2% %3%")
+ "split op cannot be greater than %2% %3%")
% inputTensorInfo.GetNumDimensions()
% MaxNumOfTensorDimensions
% CHECK_LOCATION().AsString()));
@@ -2015,7 +2019,7 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex)
}
splitterDimSizes[splitDim] /= numSplits;
- SplitterDescriptor splitDesc(numSplits);
+ SplitterDescriptor splitDesc(numSplits, inputDimSize);
for (unsigned int j = 0; j < numSplits; ++j)
{
// Set the size of the views.
@@ -2030,7 +2034,7 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex)
IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str());
auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
- RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[1]});
TensorShape outShape = TensorShape(static_cast<unsigned int>(splitterDimSizes.size()),
splitterDimSizes.data());
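
The core of this change is that ParseSplit now reads the split axis from the contents of the axis tensor's buffer (inputs[0]) instead of misusing the axis tensor's shape, and treats inputs[1] as the data tensor. Below is a minimal, self-contained C++ sketch of that buffer-read pattern. It is not the Arm NN API: the int32 element type, the negative-axis normalisation, and the function name ReadSplitAxis are assumptions made for illustration only.

#include <cassert>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

// Reads a scalar split axis out of a raw byte buffer, mirroring the
// memcpy-from-buffer pattern used in ParseSplit above.
unsigned int ReadSplitAxis(const std::vector<uint8_t>& rawAxisBuffer,
                           unsigned int numAxisElements,
                           unsigned int inputRank)
{
    // The split op carries exactly one axis value (matches the BOOST_ASSERT in the patch).
    assert(numAxisElements == 1);
    assert(rawAxisBuffer.size() >= sizeof(int32_t));

    // TFLite stores the axis as a 32-bit integer; copy it out of the raw bytes.
    int32_t axis = 0;
    std::memcpy(&axis, rawAxisBuffer.data(), sizeof(int32_t));

    // Assumption for illustration: TensorFlow Lite allows negative axes that
    // count back from the last dimension, so normalise before validating.
    if (axis < 0)
    {
        axis += static_cast<int32_t>(inputRank);
    }
    if (axis < 0 || axis >= static_cast<int32_t>(inputRank))
    {
        throw std::out_of_range("split axis is outside the input tensor's rank");
    }
    return static_cast<unsigned int>(axis);
}

In the patch itself the same idea is expressed with GetBuffer(m_Model, inputs[0]->buffer) and a memcpy into axisData, after which axisData[0] becomes splitDim and is validated against the supported channel dimensions.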