diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2019-04-18 16:56:19 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2019-04-18 16:06:57 +0000 |
commit | 17660e68c91d48bfb3fc3c9540a1834f33e9e561 (patch) | |
tree | 4bcd63260230563323144c91b4256270a86b1215 /src/armnnTfLiteParser/TfLiteParser.cpp | |
parent | 7997a3527218ed821ec933ef3a5e6a3f07409b21 (diff) | |
download | armnn-17660e68c91d48bfb3fc3c9540a1834f33e9e561.tar.gz |
IVGCVSW-2987 Modify ParseSplit in TfLite parser
* Allow input data with dimension not greater than 4D
* Correct input order
* Get split dimension from buffer data
* Unit tests
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I285851b19e6fa7c715e5fe4853df167e7c856647
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 1ee4950558..b7258b3ffc 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1971,11 +1971,15 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex) auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(outputs.size(), numSplits); - armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); - armnn::TensorInfo axisTensorInfo = ToTensorInfo(inputs[1]); + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[1]); + armnn::TensorInfo axisTensorInfo = ToTensorInfo(inputs[0]); + + BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[0]->buffer); + std::vector<unsigned int> axisData(axisTensorInfo.GetNumElements()); + ::memcpy(axisData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes()); - // This splitDim indicates the data format: 3 is the NHWC, 1 is the NCHW. - const unsigned int splitDim = static_cast<unsigned int>(axisTensorInfo.GetShape()[0]); + BOOST_ASSERT(axisTensorInfo.GetNumElements() == 1); + const unsigned int splitDim = axisData[0]; // Armnn supports split along the channel dimension for data formats NHWC and NCHW. if (splitDim == 0 || splitDim == 2) @@ -1989,13 +1993,13 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex) } auto inputDimSize = inputTensorInfo.GetNumDimensions(); - if (inputDimSize != MaxNumOfTensorDimensions) + if (inputDimSize > MaxNumOfTensorDimensions) { throw ParseException( boost::str( boost::format( "The number of dimensions: %1% for input tensors of the " - "split op should be %2% %3%") + "split op cannot be greater than %2% %3%") % inputTensorInfo.GetNumDimensions() % MaxNumOfTensorDimensions % CHECK_LOCATION().AsString())); @@ -2015,7 +2019,7 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex) } splitterDimSizes[splitDim] /= numSplits; - SplitterDescriptor splitDesc(numSplits); + SplitterDescriptor splitDesc(numSplits, inputDimSize); for (unsigned int j = 0; j < numSplits; ++j) { // Set the size of the views. @@ -2030,7 +2034,7 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex) IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str()); auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); - RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[1]}); TensorShape outShape = TensorShape(static_cast<unsigned int>(splitterDimSizes.size()), splitterDimSizes.data()); |