From 501f4d4efff78f890602d062709126f9a294a352 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Wed, 24 Apr 2019 15:52:20 +0100 Subject: IVGCVSW-2996 Add Reshape layer to ParseFullyConnected in TfLite parser when input is > 2D to flatten the input to 2D [batch_size, input_size] Change-Id: Id9d9ff996225c7d0938204ae0ceb330a11e264f5 Signed-off-by: Narumol Prangnawarat --- src/armnnTfLiteParser/TfLiteParser.cpp | 50 ++++++++++++++++++++++++--- src/armnnTfLiteParser/test/FullyConnected.cpp | 20 +++++++++++ 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 44b3614bb2..57333439aa 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1780,17 +1780,57 @@ void TfLiteParser::ParseFullyConnected(size_t subgraphIndex, size_t operatorInde } BOOST_ASSERT(layer != nullptr); - armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); - layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); - // register the input connection slot for the layer - // only the tensors for the inputs are relevant, exclude the const tensors auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); - RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); + + if (inputTensorInfo.GetNumDimensions() > 2) + { + // Add reshape to flatten to 2D [batch_size, input_size], + // where "input_size" corresponds to the number of inputs to the layer, + // matching the second dimension of weights, + // and "batch_size" is calculated by dividing the number of elements by "input_size". + std::vector reshapedDimensions(2); + reshapedDimensions[1] = filterTensorInfo.GetShape()[1]; + reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1]; + + if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0) + { + throw ParseException( + boost::str( + boost::format( + "Failed to deduce input tensor shape from filter size %1%") + % reshapedDimensions[1] + % CHECK_LOCATION().AsString())); + } + + armnn::TensorInfo reshapedTensorInfo = ToTensorInfo(inputs[0]); + reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() }); + + std::string reshapeLayerName = boost::str(boost::format("Reshape_for:%1%") % layer->GetName()); + armnn::ReshapeDescriptor desc; + desc.m_TargetShape = reshapedTensorInfo.GetShape(); + armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, layerName.c_str()); + + reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo); + reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0)); + + RegisterInputSlots(subgraphIndex, operatorIndex, reshapeLayer, {inputTensorIndexes[0]}); + } + else + { + // register the input connection slot for the layer + // only the tensors for the inputs are relevant, exclude the const tensors + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); + } + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); // we need to add the activation layer and fortunately we don't need to care about the data layout armnn::IConnectableLayer* fusedActivationLayer = AddFusedActivationLayer(layer, 0, options->fused_activation_function); + // register the output connection slots for the layer, connections are made after all layers have been created auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); RegisterOutputSlots(subgraphIndex, operatorIndex, fusedActivationLayer, {outputTensorIndexes[0]}); diff --git a/src/armnnTfLiteParser/test/FullyConnected.cpp b/src/armnnTfLiteParser/test/FullyConnected.cpp index 7ee64a476e..54d7bcb1dc 100644 --- a/src/armnnTfLiteParser/test/FullyConnected.cpp +++ b/src/armnnTfLiteParser/test/FullyConnected.cpp @@ -151,4 +151,24 @@ BOOST_FIXTURE_TEST_CASE(ParseFullyConnectedWithBias, FullyConnectedWithBiasFixtu { (400+10)/2 }); } +struct FullyConnectedWithBiasMultipleOutputsFixture : FullyConnectedFixture +{ + FullyConnectedWithBiasMultipleOutputsFixture() + : FullyConnectedFixture("[ 1, 4, 2, 1 ]", // inputShape + "[ 2, 1 ]", // outputShape + "[ 1, 4 ]", // filterShape + "[ 2, 3, 4, 5 ]", // filterData + "[ 1 ]", // biasShape + "[ 10, 0, 0, 0 ]" ) // biasData + {} +}; + +BOOST_FIXTURE_TEST_CASE(FullyConnectedWithBiasMultipleOutputs, FullyConnectedWithBiasMultipleOutputsFixture) +{ + RunTest<2, armnn::DataType::QuantisedAsymm8>( + 0, + { 1, 2, 3, 4, 10, 20, 30, 40 }, + { (40+10)/2, (400+10)/2 }); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1