aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-04-18 16:56:19 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-04-18 16:06:57 +0000
commit17660e68c91d48bfb3fc3c9540a1834f33e9e561 (patch)
tree4bcd63260230563323144c91b4256270a86b1215
parent7997a3527218ed821ec933ef3a5e6a3f07409b21 (diff)
downloadarmnn-17660e68c91d48bfb3fc3c9540a1834f33e9e561.tar.gz
IVGCVSW-2987 Modify ParseSplit in TfLite parser
* Allow input data with dimension not greater than 4D * Correct input order * Get split dimension from buffer data * Unit tests Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I285851b19e6fa7c715e5fe4853df167e7c856647
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp20
-rw-r--r--src/armnnTfLiteParser/test/Split.cpp62
2 files changed, 59 insertions, 23 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 1ee4950558..b7258b3ffc 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1971,11 +1971,15 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex)
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), numSplits);
- armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
- armnn::TensorInfo axisTensorInfo = ToTensorInfo(inputs[1]);
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[1]);
+ armnn::TensorInfo axisTensorInfo = ToTensorInfo(inputs[0]);
+
+ BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[0]->buffer);
+ std::vector<unsigned int> axisData(axisTensorInfo.GetNumElements());
+ ::memcpy(axisData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes());
- // This splitDim indicates the data format: 3 is the NHWC, 1 is the NCHW.
- const unsigned int splitDim = static_cast<unsigned int>(axisTensorInfo.GetShape()[0]);
+ BOOST_ASSERT(axisTensorInfo.GetNumElements() == 1);
+ const unsigned int splitDim = axisData[0];
// Armnn supports split along the channel dimension for data formats NHWC and NCHW.
if (splitDim == 0 || splitDim == 2)
@@ -1989,13 +1993,13 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex)
}
auto inputDimSize = inputTensorInfo.GetNumDimensions();
- if (inputDimSize != MaxNumOfTensorDimensions)
+ if (inputDimSize > MaxNumOfTensorDimensions)
{
throw ParseException(
boost::str(
boost::format(
"The number of dimensions: %1% for input tensors of the "
- "split op should be %2% %3%")
+ "split op cannot be greater than %2% %3%")
% inputTensorInfo.GetNumDimensions()
% MaxNumOfTensorDimensions
% CHECK_LOCATION().AsString()));
@@ -2015,7 +2019,7 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex)
}
splitterDimSizes[splitDim] /= numSplits;
- SplitterDescriptor splitDesc(numSplits);
+ SplitterDescriptor splitDesc(numSplits, inputDimSize);
for (unsigned int j = 0; j < numSplits; ++j)
{
// Set the size of the views.
@@ -2030,7 +2034,7 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex)
IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str());
auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
- RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[1]});
TensorShape outShape = TensorShape(static_cast<unsigned int>(splitterDimSizes.size()),
splitterDimSizes.data());
diff --git a/src/armnnTfLiteParser/test/Split.cpp b/src/armnnTfLiteParser/test/Split.cpp
index 774a416750..a6875143fa 100644
--- a/src/armnnTfLiteParser/test/Split.cpp
+++ b/src/armnnTfLiteParser/test/Split.cpp
@@ -14,11 +14,12 @@ BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
struct SplitFixture : public ParserFlatbuffersFixture
{
- explicit SplitFixture(const std::string & inputShape,
- const std::string & axisShape,
- const std::string & numSplits,
- const std::string & outputShape1,
- const std::string & outputShape2)
+ explicit SplitFixture(const std::string& inputShape,
+ const std::string& axisShape,
+ const std::string& numSplits,
+ const std::string& outputShape1,
+ const std::string& outputShape2,
+ const std::string& axisData)
{
m_JsonString = R"(
{
@@ -75,12 +76,12 @@ struct SplitFixture : public ParserFlatbuffersFixture
}
}
],
- "inputs": [ 0, 1 ],
+ "inputs": [ 0 ],
"outputs": [ 2, 3 ],
"operators": [
{
"opcode_index": 0,
- "inputs": [ 0, 1 ],
+ "inputs": [ 1, 0 ],
"outputs": [ 2, 3 ],
"builtin_options_type": "SplitOptions",
"builtin_options": {
@@ -90,7 +91,7 @@ struct SplitFixture : public ParserFlatbuffersFixture
}
],
} ],
- "buffers" : [ {}, {} ]
+ "buffers" : [ {}, {"data": )" + axisData + R"( }, {}, {} ]
}
)";
@@ -101,8 +102,8 @@ struct SplitFixture : public ParserFlatbuffersFixture
struct SimpleSplitFixture : SplitFixture
{
- SimpleSplitFixture() : SplitFixture( "[ 2, 2, 2, 2 ]", "[ 1 ]", "2",
- "[ 2, 1, 2, 2 ]", "[ 2, 1, 2, 2 ]")
+ SimpleSplitFixture() : SplitFixture( "[ 2, 2, 2, 2 ]", "[ ]", "2",
+ "[ 2, 1, 2, 2 ]", "[ 2, 1, 2, 2 ]", "[ 1, 0, 0, 0 ]")
{}
};
@@ -113,14 +114,14 @@ BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SimpleSplitFixture)
0,
{ {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } },
- { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 9.0f, 10.0f, 11.0f, 12.0f }},
- {"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f, 13.0f, 14.0f, 15.0f, 16.0f }}});
+ { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 9.0f, 10.0f, 11.0f, 12.0f } },
+ {"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f, 13.0f, 14.0f, 15.0f, 16.0f } } });
}
struct SimpleSplitAxisThreeFixture : SplitFixture
{
- SimpleSplitAxisThreeFixture() : SplitFixture( "[ 2, 2, 2, 2 ]", "[ 3 ]", "2",
- "[ 2, 2, 2, 1 ]", "[ 2, 2, 2, 1 ]")
+ SimpleSplitAxisThreeFixture() : SplitFixture( "[ 2, 2, 2, 2 ]", "[ ]", "2",
+ "[ 2, 2, 2, 1 ]", "[ 2, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]")
{}
};
@@ -130,8 +131,39 @@ BOOST_FIXTURE_TEST_CASE(ParseAxisThreeSplitTwo, SimpleSplitAxisThreeFixture)
0,
{ {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } },
- { {"outputTensor1", { 1.0f, 3.0f, 5.0f, 7.0f, 9.0f, 11.0f, 13.0f, 15.0f }},
+ { {"outputTensor1", { 1.0f, 3.0f, 5.0f, 7.0f, 9.0f, 11.0f, 13.0f, 15.0f } },
{"outputTensor2", { 2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f } } } );
}
+struct SimpleSplit2DFixture : SplitFixture
+{
+ SimpleSplit2DFixture() : SplitFixture( "[ 1, 8 ]", "[ ]", "2", "[ 1, 4 ]", "[ 1, 4 ]", "[ 1, 0, 0, 0 ]")
+ {}
+};
+
+BOOST_FIXTURE_TEST_CASE(SimpleSplit2D, SimpleSplit2DFixture)
+{
+ RunTest<2, armnn::DataType::Float32>(
+ 0,
+ { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } } },
+ { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f } },
+ {"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f } } } );
+}
+
+struct SimpleSplit3DFixture : SplitFixture
+{
+ SimpleSplit3DFixture() : SplitFixture( "[ 1, 8, 2 ]", "[ ]", "2", "[ 1, 4, 2 ]", "[ 1, 4, 2 ]", "[ 1, 0, 0, 0 ]")
+ {}
+};
+
+BOOST_FIXTURE_TEST_CASE(SimpleSplit3D, SimpleSplit3DFixture)
+{
+ RunTest<3, armnn::DataType::Float32>(
+ 0,
+ { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f,
+ 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } },
+ { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } },
+ {"outputTensor2", { 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } } );
+}
+
BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file