author     Narumol Prangnawarat <narumol.prangnawarat@arm.com>    2019-04-24 15:52:20 +0100
committer  Matteo Martincigh <matteo.martincigh@arm.com>          2019-04-29 08:36:55 +0000
commit     501f4d4efff78f890602d062709126f9a294a352 (patch)
tree       259dcc6bb6fcd895841c276ebc5025c756f656e3
parent     c01b39146d4b67159ef95a2ac2f88f6536320890 (diff)
download   armnn-501f4d4efff78f890602d062709126f9a294a352.tar.gz
IVGCVSW-2996 Add Reshape layer to ParseFullyConnected in TfLite parser
when input is > 2D, to flatten the input to 2D [batch_size, input_size]

Change-Id: Id9d9ff996225c7d0938204ae0ceb330a11e264f5
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
-rw-r--r--  src/armnnTfLiteParser/TfLiteParser.cpp         | 50
-rw-r--r--  src/armnnTfLiteParser/test/FullyConnected.cpp  | 20
2 files changed, 65 insertions(+), 5 deletions(-)
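
As context for the diff below: when the FullyConnected input has more than two dimensions, the patch flattens it to [batch_size, input_size], taking input_size from the second dimension of the weights tensor and batch_size from dividing the input element count by input_size. The standalone helper below is a minimal sketch of that rule, not parser code; DeduceFlattened2DShape is a hypothetical name.

#include <array>
#include <stdexcept>

// Sketch of the flattening rule: the target shape is [batch_size, input_size],
// where input_size is the weights tensor's second dimension and
// batch_size = totalInputElements / input_size.
std::array<unsigned int, 2> DeduceFlattened2DShape(unsigned int totalInputElements,
                                                   unsigned int weightsSecondDim)
{
    if (totalInputElements % weightsSecondDim != 0)
    {
        // Mirrors the ParseException the parser throws in this case.
        throw std::runtime_error("Failed to deduce input tensor shape from filter size");
    }
    return { totalInputElements / weightsSecondDim, weightsSecondDim };
}

// Example matching the new unit test: an input of shape [1, 4, 2, 1] has 8 elements
// and the filter shape [1, 4] gives input_size = 4, so the target shape is [2, 4]:
//     auto shape = DeduceFlattened2DShape(8, 4);   // -> { 2, 4 }
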
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 44b3614bb2..57333439aa 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1780,17 +1780,57 @@ void TfLiteParser::ParseFullyConnected(size_t subgraphIndex, size_t operatorInde
}
BOOST_ASSERT(layer != nullptr);
- armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
- layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
- // register the input connection slot for the layer
- // only the tensors for the inputs are relevant, exclude the const tensors
auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
- RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+
+ if (inputTensorInfo.GetNumDimensions() > 2)
+ {
+ // Add reshape to flatten to 2D [batch_size, input_size],
+ // where "input_size" corresponds to the number of inputs to the layer,
+ // matching the second dimension of weights,
+ // and "batch_size" is calculated by dividing the number of elements by "input_size".
+ std::vector<unsigned int> reshapedDimensions(2);
+ reshapedDimensions[1] = filterTensorInfo.GetShape()[1];
+ reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];
+
+ if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
+ {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "Failed to deduce input tensor shape from filter size %1%")
+ % reshapedDimensions[1]
+ % CHECK_LOCATION().AsString()));
+ }
+
+ armnn::TensorInfo reshapedTensorInfo = ToTensorInfo(inputs[0]);
+ reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
+
+ std::string reshapeLayerName = boost::str(boost::format("Reshape_for:%1%") % layer->GetName());
+ armnn::ReshapeDescriptor desc;
+ desc.m_TargetShape = reshapedTensorInfo.GetShape();
+ armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, reshapeLayerName.c_str());
+
+ reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
+ reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
+
+ RegisterInputSlots(subgraphIndex, operatorIndex, reshapeLayer, {inputTensorIndexes[0]});
+ }
+ else
+ {
+ // register the input connection slot for the layer
+ // only the tensors for the inputs are relevant, exclude the const tensors
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+ }
+
+ armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
// we need to add the activation layer and fortunately we don't need to care about the data layout
armnn::IConnectableLayer* fusedActivationLayer = AddFusedActivationLayer(layer, 0,
options->fused_activation_function);
+
// register the output connection slots for the layer, connections are made after all layers have been created
auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
RegisterOutputSlots(subgraphIndex, operatorIndex, fusedActivationLayer, {outputTensorIndexes[0]});
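
For readers following the wiring above: for a >2D input the parser now produces Input -> Reshape -> FullyConnected rather than registering the input slot on the FullyConnected layer directly. The fragment below sketches the equivalent topology built by hand through ArmNN's public graph API. Shapes, names, and weight data mirror the new unit test, the Float32 data types are purely illustrative, and the AddFullyConnectedLayer overload taking weights and an optional bias is assumed from the API of this era.

#include <armnn/Descriptors.hpp>
#include <armnn/INetwork.hpp>
#include <armnn/Optional.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
#include <vector>

// Illustrative only: the topology the parser now builds for a 4D FullyConnected input.
armnn::INetworkPtr BuildFlattenedFullyConnected()
{
    armnn::INetworkPtr network = armnn::INetwork::Create();

    // 4D input [1, 4, 2, 1], as in the new unit test.
    armnn::IConnectableLayer* input = network->AddInputLayer(0, "input");
    input->GetOutputSlot(0).SetTensorInfo(
        armnn::TensorInfo(armnn::TensorShape({ 1, 4, 2, 1 }), armnn::DataType::Float32));

    // Reshape to [batch_size, input_size] = [2, 4].
    armnn::ReshapeDescriptor reshapeDesc;
    reshapeDesc.m_TargetShape = armnn::TensorShape({ 2, 4 });
    armnn::IConnectableLayer* reshape = network->AddReshapeLayer(reshapeDesc, "Reshape_for:fc");
    reshape->GetOutputSlot(0).SetTensorInfo(
        armnn::TensorInfo(reshapeDesc.m_TargetShape, armnn::DataType::Float32));

    // FullyConnected with TfLite-layout weights [output_size, input_size] = [1, 4], no bias.
    std::vector<float> weightsData = { 2.0f, 3.0f, 4.0f, 5.0f };
    armnn::ConstTensor weights(
        armnn::TensorInfo(armnn::TensorShape({ 1, 4 }), armnn::DataType::Float32), weightsData);
    armnn::FullyConnectedDescriptor fcDesc;
    fcDesc.m_BiasEnabled = false;
    fcDesc.m_TransposeWeightMatrix = true; // TfLite stores weights as [output_size, input_size]
    armnn::IConnectableLayer* fc =
        network->AddFullyConnectedLayer(fcDesc, weights, armnn::EmptyOptional(), "fc");
    fc->GetOutputSlot(0).SetTensorInfo(
        armnn::TensorInfo(armnn::TensorShape({ 2, 1 }), armnn::DataType::Float32));

    armnn::IConnectableLayer* output = network->AddOutputLayer(0, "output");

    // input -> reshape -> fully connected -> output
    input->GetOutputSlot(0).Connect(reshape->GetInputSlot(0));
    reshape->GetOutputSlot(0).Connect(fc->GetInputSlot(0));
    fc->GetOutputSlot(0).Connect(output->GetInputSlot(0));

    return network;
}
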
diff --git a/src/armnnTfLiteParser/test/FullyConnected.cpp b/src/armnnTfLiteParser/test/FullyConnected.cpp
index 7ee64a476e..54d7bcb1dc 100644
--- a/src/armnnTfLiteParser/test/FullyConnected.cpp
+++ b/src/armnnTfLiteParser/test/FullyConnected.cpp
@@ -151,4 +151,24 @@ BOOST_FIXTURE_TEST_CASE(ParseFullyConnectedWithBias, FullyConnectedWithBiasFixtu
{ (400+10)/2 });
}
+struct FullyConnectedWithBiasMultipleOutputsFixture : FullyConnectedFixture
+{
+ FullyConnectedWithBiasMultipleOutputsFixture()
+ : FullyConnectedFixture("[ 1, 4, 2, 1 ]", // inputShape
+ "[ 2, 1 ]", // outputShape
+ "[ 1, 4 ]", // filterShape
+ "[ 2, 3, 4, 5 ]", // filterData
+ "[ 1 ]", // biasShape
+ "[ 10, 0, 0, 0 ]" ) // biasData
+ {}
+};
+
+BOOST_FIXTURE_TEST_CASE(FullyConnectedWithBiasMultipleOutputs, FullyConnectedWithBiasMultipleOutputsFixture)
+{
+ RunTest<2, armnn::DataType::QuantisedAsymm8>(
+ 0,
+ { 1, 2, 3, 4, 10, 20, 30, 40 },
+ { (40+10)/2, (400+10)/2 });
+}
+
BOOST_AUTO_TEST_SUITE_END()
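
The expected values in the new test follow from the flattening: the eight input values become two rows of four, each row is dotted with the filter [2, 3, 4, 5] and offset by the bias of 10, and the result is divided by the output quantization scale, which the existing (400+10)/2 expectation implies is 2 with zero offsets (an assumption about the fixture, not stated in this diff). A quick standalone check of that arithmetic:

#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main()
{
    // Input flattened to [2, 4], as the parser now reshapes the [1, 4, 2, 1] input.
    std::vector<std::vector<int32_t>> batches = { { 1, 2, 3, 4 }, { 10, 20, 30, 40 } };
    std::vector<int32_t> filter = { 2, 3, 4, 5 };
    int32_t bias = 10;
    int32_t outputScale = 2; // assumed from the (400+10)/2 expectation in the existing test

    std::vector<int32_t> expected = { (40 + 10) / 2, (400 + 10) / 2 }; // { 25, 205 }
    for (size_t b = 0; b < batches.size(); ++b)
    {
        // Dot product of one flattened batch with the filter, plus bias, then rescaled.
        int32_t acc = std::inner_product(batches[b].begin(), batches[b].end(), filter.begin(), 0);
        assert((acc + bias) / outputScale == expected[b]);
    }
    return 0;
}
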