aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruno Goncalves <bruno.slackware@gmail.com>2018-12-27 14:20:35 -0200
committerMatthew Bentham <matthew.bentham@arm.com>2019-01-22 17:10:06 +0000
commit9c761a6335f58901bef01b33f1127f9fae3b2bf3 (patch)
tree114970cbde18a4facdc655a2bb917fd06390f168
parent33f8e3b6c71070fd867809ca6934069a950081dc (diff)
downloadarmnn-9c761a6335f58901bef01b33f1127f9fae3b2bf3.tar.gz
Added AddBroadcastReshapeLayer method to TfLiteParser
Change-Id: I6027f6dcdb3ed23505f0a9c780bd3e3d45d3daff
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp81
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.hpp4
-rw-r--r--src/armnnTfLiteParser/test/Multiplication.cpp36
3 files changed, 119 insertions, 2 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 0e8d3c5b68..c45e794274 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -463,6 +463,63 @@ void TfLiteParser::ResetParser()
m_SubgraphConnections.clear();
}
+void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex,
+ size_t operatorIndex,
+ IConnectableLayer *layer)
+{
+ CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+ BOOST_ASSERT(layer != nullptr);
+
+ const auto & subGraphPtr = m_Model->subgraphs[subgraphIndex];
+ const auto & operatorPtr = subGraphPtr->operators[operatorIndex];
+
+ BOOST_ASSERT(operatorPtr->inputs.size() > 1);
+
+ uint32_t reshapedInputId = CHECKED_NON_NEGATIVE(operatorPtr->inputs[0]);
+ TensorRawPtr tensorPtr = subGraphPtr->tensors[reshapedInputId].get();
+ uint32_t inputId = CHECKED_NON_NEGATIVE(operatorPtr->inputs[1]);
+ TensorRawPtr tensorPtr1 = subGraphPtr->tensors[inputId].get();
+
+ armnn::TensorInfo reshapedTensorInfo = ToTensorInfo(tensorPtr);
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(tensorPtr1);
+
+ if (inputTensorInfo.GetNumDimensions() < reshapedTensorInfo.GetNumDimensions())
+ {
+ uint32_t id = reshapedInputId;
+ reshapedInputId = inputId;
+ inputId = id;
+
+ reshapedTensorInfo = ToTensorInfo(tensorPtr1);
+ inputTensorInfo = ToTensorInfo(tensorPtr);
+ }
+
+ uint32_t numDimensions = inputTensorInfo.GetNumDimensions();
+
+ std::vector<unsigned> reshapedDim;
+ for (unsigned int i = 0; i < reshapedTensorInfo.GetNumDimensions(); ++i)
+ {
+ reshapedDim.push_back(reshapedTensorInfo.GetShape()[i]);
+ }
+
+ std::vector<unsigned int> reshapedDimensions(numDimensions, 1);
+ std::copy_backward (reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());
+
+ reshapedTensorInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() });
+
+ std::string layerName = boost::str(boost::format("Reshape_for:%1%") % layer->GetName());
+ armnn::ReshapeDescriptor desc;
+ desc.m_TargetShape = reshapedTensorInfo.GetShape();
+ armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, layerName.c_str());
+
+ reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
+ reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
+
+ RegisterInputSlots(subgraphIndex, operatorIndex, reshapeLayer, {reshapedInputId});
+
+ armnn::IInputSlot* input1Slot = &(layer->GetInputSlot(1));
+ RegisterConsumerOfTensor(subgraphIndex, inputId, input1Slot);
+}
+
INetworkPtr TfLiteParser::CreateNetworkFromBinaryFile(const char* graphFile)
{
ResetParser();
@@ -1008,6 +1065,9 @@ void TfLiteParser::ParseAdd(size_t subgraphIndex, size_t operatorIndex)
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+ armnn::TensorInfo input1TensorInfo = ToTensorInfo(inputs[1]);
+
auto layerName = boost::str(boost::format("Add:%1%:%2%") % subgraphIndex % operatorIndex);
IConnectableLayer* layer = m_Network->AddAdditionLayer(layerName.c_str());
@@ -1015,7 +1075,14 @@ void TfLiteParser::ParseAdd(size_t subgraphIndex, size_t operatorIndex)
layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
- RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
+ if (inputTensorInfo.GetNumDimensions() != input1TensorInfo.GetNumDimensions())
+ {
+ AddBroadcastReshapeLayer(subgraphIndex, operatorIndex, layer);
+ }
+ else
+ {
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
+ }
layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
@@ -1036,6 +1103,9 @@ void TfLiteParser::ParseMul(size_t subgraphIndex, size_t operatorIndex)
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+ armnn::TensorInfo input1TensorInfo = ToTensorInfo(inputs[1]);
+
auto layerName = boost::str(boost::format("Mul:%1%:%2%") % subgraphIndex % operatorIndex);
IConnectableLayer* layer = m_Network->AddMultiplicationLayer(layerName.c_str());
@@ -1043,7 +1113,14 @@ void TfLiteParser::ParseMul(size_t subgraphIndex, size_t operatorIndex)
layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
- RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
+ if (inputTensorInfo.GetNumDimensions() != input1TensorInfo.GetNumDimensions())
+ {
+ AddBroadcastReshapeLayer(subgraphIndex, operatorIndex, layer);
+ }
+ else
+ {
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
+ }
layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 6c264372ba..34ae07f392 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -124,6 +124,10 @@ private:
void ResetParser();
+ void AddBroadcastReshapeLayer(size_t subgraphIndex,
+ size_t operatorIndex,
+ armnn::IConnectableLayer* layer);
+
/// Attach an activation layer to the one passed as a parameter
armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer,
unsigned int outputSlot,
diff --git a/src/armnnTfLiteParser/test/Multiplication.cpp b/src/armnnTfLiteParser/test/Multiplication.cpp
index f7e2edd546..dabf868559 100644
--- a/src/armnnTfLiteParser/test/Multiplication.cpp
+++ b/src/armnnTfLiteParser/test/Multiplication.cpp
@@ -108,4 +108,40 @@ BOOST_FIXTURE_TEST_CASE(ParseMultiplication, SimpleMultiplicationFixture)
45.0f, 50.0f, 55.0f } } });
}
+struct MultiplicationBroadcastFixture4D1D : public MultiplicationFixture
+{
+ MultiplicationBroadcastFixture4D1D() : MultiplicationFixture("[ 1, 2, 2, 3 ]", "[ 1 ]", "[ 1, 2, 2, 3 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseMultiplicationBroadcast4D1D, MultiplicationBroadcastFixture4D1D)
+{
+ RunTest<4, float>(0, {{ "inputTensor1", { 0.0f, 1.0f, 2.0f,
+ 3.0f, 4.0f, 5.0f,
+ 6.0f, 7.0f, 8.0f,
+ 9.0f, 10.0f, 11.0f } },
+ { "inputTensor2", { 5.0f } } },
+ {{ "outputTensor", { 0.0f, 5.0f, 10.0f,
+ 15.0f, 20.0f, 25.0f,
+ 30.0f, 35.0f, 40.0f,
+ 45.0f, 50.0f, 55.0f } } });
+}
+
+struct MultiplicationBroadcastFixture1D4D : public MultiplicationFixture
+{
+ MultiplicationBroadcastFixture1D4D() : MultiplicationFixture("[ 1 ]", "[ 1, 2, 2, 3 ]", "[ 1, 2, 2, 3 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseMultiplicationBroadcast1D4D, MultiplicationBroadcastFixture1D4D)
+{
+ RunTest<4, float>(0, {{ "inputTensor1", { 3.0f } },
+ { "inputTensor2", { 0.0f, 1.0f, 2.0f,
+ 3.0f, 4.0f, 5.0f,
+ 6.0f, 7.0f, 8.0f,
+ 9.0f, 10.0f, 11.0f } } },
+ {{ "outputTensor", { 0.0f, 3.0f, 6.0f,
+ 9.0f, 12.0f, 15.0f,
+ 18.0f, 21.0f, 24.0f,
+ 27.0f, 30.0f, 33.0f } } });
+}
+
BOOST_AUTO_TEST_SUITE_END()