From 17660e68c91d48bfb3fc3c9540a1834f33e9e561 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Thu, 18 Apr 2019 16:56:19 +0100 Subject: IVGCVSW-2987 Modify ParseSplit in TfLite parser * Allow input data with dimension not greater than 4D * Correct input order * Get split dimension from buffer data * Unit tests Signed-off-by: Narumol Prangnawarat Change-Id: I285851b19e6fa7c715e5fe4853df167e7c856647 --- src/armnnTfLiteParser/TfLiteParser.cpp | 20 ++++++----- src/armnnTfLiteParser/test/Split.cpp | 62 ++++++++++++++++++++++++++-------- 2 files changed, 59 insertions(+), 23 deletions(-) diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 1ee4950558..b7258b3ffc 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1971,11 +1971,15 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex) auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(outputs.size(), numSplits); - armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); - armnn::TensorInfo axisTensorInfo = ToTensorInfo(inputs[1]); + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[1]); + armnn::TensorInfo axisTensorInfo = ToTensorInfo(inputs[0]); + + BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[0]->buffer); + std::vector axisData(axisTensorInfo.GetNumElements()); + ::memcpy(axisData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes()); - // This splitDim indicates the data format: 3 is the NHWC, 1 is the NCHW. - const unsigned int splitDim = static_cast(axisTensorInfo.GetShape()[0]); + BOOST_ASSERT(axisTensorInfo.GetNumElements() == 1); + const unsigned int splitDim = axisData[0]; // Armnn supports split along the channel dimension for data formats NHWC and NCHW. if (splitDim == 0 || splitDim == 2) @@ -1989,13 +1993,13 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex) } auto inputDimSize = inputTensorInfo.GetNumDimensions(); - if (inputDimSize != MaxNumOfTensorDimensions) + if (inputDimSize > MaxNumOfTensorDimensions) { throw ParseException( boost::str( boost::format( "The number of dimensions: %1% for input tensors of the " - "split op should be %2% %3%") + "split op cannot be greater than %2% %3%") % inputTensorInfo.GetNumDimensions() % MaxNumOfTensorDimensions % CHECK_LOCATION().AsString())); @@ -2015,7 +2019,7 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex) } splitterDimSizes[splitDim] /= numSplits; - SplitterDescriptor splitDesc(numSplits); + SplitterDescriptor splitDesc(numSplits, inputDimSize); for (unsigned int j = 0; j < numSplits; ++j) { // Set the size of the views. @@ -2030,7 +2034,7 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex) IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str()); auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); - RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[1]}); TensorShape outShape = TensorShape(static_cast(splitterDimSizes.size()), splitterDimSizes.data()); diff --git a/src/armnnTfLiteParser/test/Split.cpp b/src/armnnTfLiteParser/test/Split.cpp index 774a416750..a6875143fa 100644 --- a/src/armnnTfLiteParser/test/Split.cpp +++ b/src/armnnTfLiteParser/test/Split.cpp @@ -14,11 +14,12 @@ BOOST_AUTO_TEST_SUITE(TensorflowLiteParser) struct SplitFixture : public ParserFlatbuffersFixture { - explicit SplitFixture(const std::string & inputShape, - const std::string & axisShape, - const std::string & numSplits, - const std::string & outputShape1, - const std::string & outputShape2) + explicit SplitFixture(const std::string& inputShape, + const std::string& axisShape, + const std::string& numSplits, + const std::string& outputShape1, + const std::string& outputShape2, + const std::string& axisData) { m_JsonString = R"( { @@ -75,12 +76,12 @@ struct SplitFixture : public ParserFlatbuffersFixture } } ], - "inputs": [ 0, 1 ], + "inputs": [ 0 ], "outputs": [ 2, 3 ], "operators": [ { "opcode_index": 0, - "inputs": [ 0, 1 ], + "inputs": [ 1, 0 ], "outputs": [ 2, 3 ], "builtin_options_type": "SplitOptions", "builtin_options": { @@ -90,7 +91,7 @@ struct SplitFixture : public ParserFlatbuffersFixture } ], } ], - "buffers" : [ {}, {} ] + "buffers" : [ {}, {"data": )" + axisData + R"( }, {}, {} ] } )"; @@ -101,8 +102,8 @@ struct SplitFixture : public ParserFlatbuffersFixture struct SimpleSplitFixture : SplitFixture { - SimpleSplitFixture() : SplitFixture( "[ 2, 2, 2, 2 ]", "[ 1 ]", "2", - "[ 2, 1, 2, 2 ]", "[ 2, 1, 2, 2 ]") + SimpleSplitFixture() : SplitFixture( "[ 2, 2, 2, 2 ]", "[ ]", "2", + "[ 2, 1, 2, 2 ]", "[ 2, 1, 2, 2 ]", "[ 1, 0, 0, 0 ]") {} }; @@ -113,14 +114,14 @@ BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SimpleSplitFixture) 0, { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } }, - { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 9.0f, 10.0f, 11.0f, 12.0f }}, - {"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f, 13.0f, 14.0f, 15.0f, 16.0f }}}); + { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 9.0f, 10.0f, 11.0f, 12.0f } }, + {"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f, 13.0f, 14.0f, 15.0f, 16.0f } } }); } struct SimpleSplitAxisThreeFixture : SplitFixture { - SimpleSplitAxisThreeFixture() : SplitFixture( "[ 2, 2, 2, 2 ]", "[ 3 ]", "2", - "[ 2, 2, 2, 1 ]", "[ 2, 2, 2, 1 ]") + SimpleSplitAxisThreeFixture() : SplitFixture( "[ 2, 2, 2, 2 ]", "[ ]", "2", + "[ 2, 2, 2, 1 ]", "[ 2, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]") {} }; @@ -130,8 +131,39 @@ BOOST_FIXTURE_TEST_CASE(ParseAxisThreeSplitTwo, SimpleSplitAxisThreeFixture) 0, { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } }, - { {"outputTensor1", { 1.0f, 3.0f, 5.0f, 7.0f, 9.0f, 11.0f, 13.0f, 15.0f }}, + { {"outputTensor1", { 1.0f, 3.0f, 5.0f, 7.0f, 9.0f, 11.0f, 13.0f, 15.0f } }, {"outputTensor2", { 2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f } } } ); } +struct SimpleSplit2DFixture : SplitFixture +{ + SimpleSplit2DFixture() : SplitFixture( "[ 1, 8 ]", "[ ]", "2", "[ 1, 4 ]", "[ 1, 4 ]", "[ 1, 0, 0, 0 ]") + {} +}; + +BOOST_FIXTURE_TEST_CASE(SimpleSplit2D, SimpleSplit2DFixture) +{ + RunTest<2, armnn::DataType::Float32>( + 0, + { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } } }, + { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f } }, + {"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f } } } ); +} + +struct SimpleSplit3DFixture : SplitFixture +{ + SimpleSplit3DFixture() : SplitFixture( "[ 1, 8, 2 ]", "[ ]", "2", "[ 1, 4, 2 ]", "[ 1, 4, 2 ]", "[ 1, 0, 0, 0 ]") + {} +}; + +BOOST_FIXTURE_TEST_CASE(SimpleSplit3D, SimpleSplit3DFixture) +{ + RunTest<3, armnn::DataType::Float32>( + 0, + { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } }, + { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } }, + {"outputTensor2", { 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } } ); +} + BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file -- cgit v1.2.1