author     Narumol Prangnawarat <narumol.prangnawarat@arm.com>    2019-04-23 15:28:06 +0100
committer  Ruomei Yan <ruomei.yan@arm.com>                        2019-04-25 10:26:19 +0000
commit     672de578c819a3d815e037cddfaf4dcf49d12917 (patch)
tree       a85e588ce044b4733919b05bff0244d01a19cb93
parent     4a95611748583f35f0f1f8388a3ba261e63436b7 (diff)
download   armnn-672de578c819a3d815e037cddfaf4dcf49d12917.tar.gz

IVGCVSW-2994 Add Reshape layer to ParseUnpack in TfLite parser

to remove the unpacked dimension of each output from Splitter and correct
ReshapeFixtureWithReshapeDimsFlatten test output shape

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I517d315475612ac8b773930f9b58cac316fa8553

-rw-r--r--  src/armnnTfLiteParser/TfLiteParser.cpp   53
-rw-r--r--  src/armnnTfLiteParser/test/Reshape.cpp     6
2 files changed, 49 insertions(+), 10 deletions(-)
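Background note: the change below decomposes a TfLite UNPACK operator into a Splitter followed by one Reshape per output, so that each output drops the dimension being unpacked. As a rough illustration of that shape bookkeeping only, here is a small standalone C++ sketch; ReshapedUnpackDims is a hypothetical helper for this note, not part of the parser.

#include <iostream>
#include <vector>

// Illustrative only: given the input dimensions and the unpack axis, compute
// the shape each Splitter output is reshaped to (the unpack axis is dropped).
std::vector<unsigned int> ReshapedUnpackDims(const std::vector<unsigned int>& inputDims,
                                             unsigned int unpackAxis)
{
    std::vector<unsigned int> reshaped;
    reshaped.reserve(inputDims.size() - 1);
    for (unsigned int i = 0; i < inputDims.size(); ++i)
    {
        if (i == unpackAxis)
        {
            continue; // drop the dimension being unpacked
        }
        reshaped.push_back(inputDims[i]);
    }
    return reshaped;
}

int main()
{
    // Unpacking a [4, 2, 3] tensor along axis 1 yields 2 outputs of shape [4, 3];
    // a Splitter alone produces [4, 1, 3] views, hence the per-output Reshape.
    for (unsigned int d : ReshapedUnpackDims({4, 2, 3}, 1))
    {
        std::cout << d << ' '; // prints: 4 3
    }
    std::cout << '\n';
}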
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index b7258b3ffc..44b3614bb2 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1888,6 +1888,19 @@ void TfLiteParser::ParseUnpack(size_t subgraphIndex, size_t operatorIndex)
    CHECK_VALID_SIZE(inputs.size(), 1);
    armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+
+    if (unpackAxis >= inputTensorInfo.GetNumDimensions())
+    {
+        throw ParseException(
+                boost::str(
+                    boost::format(
+                        "The unpack axis: %1% cannot be greater than or equal to "
+                        "the number of input dimensions %2% %3%")
+                    % unpackAxis
+                    % inputTensorInfo.GetNumDimensions()
+                    % CHECK_LOCATION().AsString()));
+    }
+
    unsigned int unpackNum = CHECKED_NON_NEGATIVE(options->num);

    // If num is not defined, automatically infer from the length of the dimension axis.
    if(unpackNum == 0)
@@ -1935,20 +1948,46 @@ void TfLiteParser::ParseUnpack(size_t subgraphIndex, size_t operatorIndex)
    auto layerName = boost::str(boost::format("Unpack:%1%:%2%") % subgraphIndex % operatorIndex);
    IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str());

+    TensorShape splitOutShape = TensorShape(static_cast<unsigned int>(unpackDimSizes.size()),
+                                            unpackDimSizes.data());
+
    auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
    RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});

-    TensorShape outShape = TensorShape(static_cast<unsigned int>(unpackDimSizes.size()),
-                                       unpackDimSizes.data());
+    // Reshape to remove unpacked dimension
+    unsigned int reshapedNumDimensions = inputDimSize - 1;
+    std::vector<unsigned int> reshapedDimensions(reshapedNumDimensions);

-    for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
+    unsigned int reshapeIndex = 0;
+    for (unsigned int i = 0; i < inputDimSize; ++i)
    {
-        layer->GetOutputSlot(k).SetTensorInfo(armnn::TensorInfo(outShape,
-                                                                inputTensorInfo.GetDataType()));
+        if (i == unpackAxis)
+        {
+            continue;
+        }
+        reshapedDimensions[reshapeIndex++] = unpackDimSizes[i];
    }

-    auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
-    RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
+    // Create a Reshape layer for each Splitter output to remove the unpacked dimension.
+    for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
+    {
+        armnn::TensorInfo reshapedTensorInfo = inputTensorInfo;
+        reshapedTensorInfo.SetShape(armnn::TensorShape{ reshapedNumDimensions, reshapedDimensions.data() });
+
+        std::string reshapeLayerName = boost::str(boost::format("Reshape_for:%1%") % layer->GetName());
+        armnn::ReshapeDescriptor desc;
+        desc.m_TargetShape = reshapedTensorInfo.GetShape();
+        armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, reshapeLayerName.c_str());
+
+        layer->GetOutputSlot(k).SetTensorInfo(armnn::TensorInfo(splitOutShape, inputTensorInfo.GetDataType()));
+        layer->GetOutputSlot(k).Connect(reshapeLayer->GetInputSlot(0));
+
+        reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
+
+        uint32_t reshapedOutputId = CHECKED_NON_NEGATIVE(operatorPtr->outputs[k]);
+        armnn::IOutputSlot* slot = &(reshapeLayer->GetOutputSlot(0));
+        RegisterProducerOfTensor(subgraphIndex, reshapedOutputId, slot);
+    }
}

void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex)
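
Why the extra Reshape is needed: the Splitter views built earlier in ParseUnpack keep the full input rank, with extent 1 along the unpack axis, so each Splitter output still carries a singleton dimension that the new per-output Reshape removes. The sketch below illustrates that view layout under that assumption; View and UnpackViews are hypothetical stand-ins for this note, not the parser's ViewsDescriptor API.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for a splitter view: an origin and a size per output.
struct View
{
    std::vector<uint32_t> origin;
    std::vector<uint32_t> size;
};

// Illustrative only: for UNPACK along `axis`, every view keeps the input rank,
// has extent 1 along the unpack axis, and is offset by its index on that axis.
std::vector<View> UnpackViews(const std::vector<uint32_t>& inputDims, uint32_t axis)
{
    std::vector<View> views(inputDims[axis]);
    for (uint32_t k = 0; k < inputDims[axis]; ++k)
    {
        views[k].origin.assign(inputDims.size(), 0);
        views[k].origin[axis] = k;   // slide along the unpack axis
        views[k].size = inputDims;
        views[k].size[axis] = 1;     // singleton extent that the Reshape removes
    }
    return views;
}

int main()
{
    // A [4, 2, 3] input unpacked on axis 1 gives two [4, 1, 3] views.
    for (const View& v : UnpackViews({4, 2, 3}, 1))
    {
        std::cout << "origin " << v.origin[0] << ',' << v.origin[1] << ',' << v.origin[2]
                  << "  size " << v.size[0] << ',' << v.size[1] << ',' << v.size[2] << '\n';
    }
}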
diff --git a/src/armnnTfLiteParser/test/Reshape.cpp b/src/armnnTfLiteParser/test/Reshape.cpp
index ef4b761945..62fbad6953 100644
--- a/src/armnnTfLiteParser/test/Reshape.cpp
+++ b/src/armnnTfLiteParser/test/Reshape.cpp
@@ -95,17 +95,17 @@ BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDims, ReshapeFixtureWithReshapeDi
struct ReshapeFixtureWithReshapeDimsFlatten : ReshapeFixture
{
-    ReshapeFixtureWithReshapeDimsFlatten() : ReshapeFixture("[ 3, 3 ]", "[ 1, 9 ]", "[ -1 ]") {}
+    ReshapeFixtureWithReshapeDimsFlatten() : ReshapeFixture("[ 3, 3 ]", "[ 9 ]", "[ -1 ]") {}
};

BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDimsFlatten, ReshapeFixtureWithReshapeDimsFlatten)
{
    SetupSingleInputSingleOutput("inputTensor", "outputTensor");
-    RunTest<2, armnn::DataType::QuantisedAsymm8>(0,
+    RunTest<1, armnn::DataType::QuantisedAsymm8>(0,
                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
    BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
-                == armnn::TensorShape({1,9})));
+                == armnn::TensorShape({9})));
}

struct ReshapeFixtureWithReshapeDimsFlattenTwoDims : ReshapeFixture
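
On the test change above: a TfLite RESHAPE whose shape spec is [ -1 ] flattens a [ 3, 3 ] input into a rank-1 [ 9 ] tensor rather than [ 1, 9 ], which is why the fixture's output shape, the RunTest rank, and the expected TensorShape were corrected. Below is a minimal sketch of the -1 inference rule, assuming a single -1 entry; InferReshape is a hypothetical helper, not the parser's implementation.

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only: infer a single -1 entry in a TFLite-style reshape spec
// from the total element count. For a [3, 3] input and spec [-1] this yields
// [9] (rank 1), matching the corrected expected shape in the test above.
std::vector<uint32_t> InferReshape(const std::vector<uint32_t>& inputDims,
                                   const std::vector<int32_t>& spec)
{
    uint32_t totalElements = 1;
    for (uint32_t d : inputDims)
    {
        totalElements *= d;
    }

    uint32_t knownProduct = 1;
    for (int32_t d : spec)
    {
        if (d > 0)
        {
            knownProduct *= static_cast<uint32_t>(d);
        }
    }

    std::vector<uint32_t> outputDims;
    for (int32_t d : spec)
    {
        outputDims.push_back(d == -1 ? totalElements / knownProduct
                                     : static_cast<uint32_t>(d));
    }
    return outputDims;
}

int main()
{
    for (uint32_t d : InferReshape({3, 3}, {-1}))
    {
        std::cout << d << '\n'; // prints: 9
    }
}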