From 672de578c819a3d815e037cddfaf4dcf49d12917 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Tue, 23 Apr 2019 15:28:06 +0100 Subject: IVGCVSW-2994 Add Reshape layer to ParseUnpack in TfLite parser to remove the unpacked dimension of each output from Splitter and correct ReshapeFixtureWithReshapeDimsFlatten test output shape Signed-off-by: Narumol Prangnawarat Change-Id: I517d315475612ac8b773930f9b58cac316fa8553 --- src/armnnTfLiteParser/TfLiteParser.cpp | 53 +++++++++++++++++++++++++++++----- src/armnnTfLiteParser/test/Reshape.cpp | 6 ++-- 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index b7258b3ffc..44b3614bb2 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1888,6 +1888,19 @@ void TfLiteParser::ParseUnpack(size_t subgraphIndex, size_t operatorIndex) CHECK_VALID_SIZE(inputs.size(), 1); armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); + + if (unpackAxis >= inputTensorInfo.GetNumDimensions()) + { + throw ParseException( + boost::str( + boost::format( + "The unpack axis: %1% cannot be greater than or equal to " + "the number of input dimension %2% %3%") + % unpackAxis + % inputTensorInfo.GetNumDimensions() + % CHECK_LOCATION().AsString())); + } + unsigned int unpackNum = CHECKED_NON_NEGATIVE(options->num); // If num is not defined, automatically infer from the length of the dimension axis. if(unpackNum == 0) @@ -1935,20 +1948,46 @@ void TfLiteParser::ParseUnpack(size_t subgraphIndex, size_t operatorIndex) auto layerName = boost::str(boost::format("Unpack:%1%:%2%") % subgraphIndex % operatorIndex); IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str()); + TensorShape splitOutShape = TensorShape(static_cast(unpackDimSizes.size()), + unpackDimSizes.data()); + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); - TensorShape outShape = TensorShape(static_cast(unpackDimSizes.size()), - unpackDimSizes.data()); + // Reshape to remove unpacked dimension + unsigned int reshapedNumDimensions = inputDimSize - 1; + std::vector reshapedDimensions(reshapedNumDimensions); - for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k) + unsigned int reshapeIndex = 0; + for (unsigned int i = 0; i < inputDimSize; ++i) { - layer->GetOutputSlot(k).SetTensorInfo(armnn::TensorInfo(outShape, - inputTensorInfo.GetDataType())); + if (i == unpackAxis) + { + continue; + } + reshapedDimensions[reshapeIndex++] = unpackDimSizes[i]; } - auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); - RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes); + // Create reshape to remove the unpacked dimension for unpack operator of each output from Splitter. + for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k) + { + armnn::TensorInfo reshapedTensorInfo = inputTensorInfo; + reshapedTensorInfo.SetShape(armnn::TensorShape{ reshapedNumDimensions, reshapedDimensions.data() }); + + std::string reshapeLayerName = boost::str(boost::format("Reshape_for:%1%") % layer->GetName()); + armnn::ReshapeDescriptor desc; + desc.m_TargetShape = reshapedTensorInfo.GetShape(); + armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, layerName.c_str()); + + layer->GetOutputSlot(k).SetTensorInfo(armnn::TensorInfo(splitOutShape, inputTensorInfo.GetDataType())); + layer->GetOutputSlot(k).Connect(reshapeLayer->GetInputSlot(0)); + + reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo); + + uint32_t reshapedOutputId = CHECKED_NON_NEGATIVE(operatorPtr->outputs[k]); + armnn::IOutputSlot* slot = &(reshapeLayer->GetOutputSlot(0)); + RegisterProducerOfTensor(subgraphIndex, reshapedOutputId, slot); + } } void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex) diff --git a/src/armnnTfLiteParser/test/Reshape.cpp b/src/armnnTfLiteParser/test/Reshape.cpp index ef4b761945..62fbad6953 100644 --- a/src/armnnTfLiteParser/test/Reshape.cpp +++ b/src/armnnTfLiteParser/test/Reshape.cpp @@ -95,17 +95,17 @@ BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDims, ReshapeFixtureWithReshapeDi struct ReshapeFixtureWithReshapeDimsFlatten : ReshapeFixture { - ReshapeFixtureWithReshapeDimsFlatten() : ReshapeFixture("[ 3, 3 ]", "[ 1, 9 ]", "[ -1 ]") {} + ReshapeFixtureWithReshapeDimsFlatten() : ReshapeFixture("[ 3, 3 ]", "[ 9 ]", "[ -1 ]") {} }; BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDimsFlatten, ReshapeFixtureWithReshapeDimsFlatten) { SetupSingleInputSingleOutput("inputTensor", "outputTensor"); - RunTest<2, armnn::DataType::QuantisedAsymm8>(0, + RunTest<1, armnn::DataType::QuantisedAsymm8>(0, { 1, 2, 3, 4, 5, 6, 7, 8, 9 }, { 1, 2, 3, 4, 5, 6, 7, 8, 9 }); BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape() - == armnn::TensorShape({1,9}))); + == armnn::TensorShape({9}))); } struct ReshapeFixtureWithReshapeDimsFlattenTwoDims : ReshapeFixture -- cgit v1.2.1