From 9c761a6335f58901bef01b33f1127f9fae3b2bf3 Mon Sep 17 00:00:00 2001 From: Bruno Goncalves Date: Thu, 27 Dec 2018 14:20:35 -0200 Subject: Added AddBroadcastReshapeLayer method to TfLiteParser Change-Id: I6027f6dcdb3ed23505f0a9c780bd3e3d45d3daff --- src/armnnTfLiteParser/TfLiteParser.cpp | 81 ++++++++++++++++++++++++++- src/armnnTfLiteParser/TfLiteParser.hpp | 4 ++ src/armnnTfLiteParser/test/Multiplication.cpp | 36 ++++++++++++ 3 files changed, 119 insertions(+), 2 deletions(-) diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 0e8d3c5b68..c45e794274 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -463,6 +463,63 @@ void TfLiteParser::ResetParser() m_SubgraphConnections.clear(); } +void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex, + size_t operatorIndex, + IConnectableLayer *layer) +{ + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + BOOST_ASSERT(layer != nullptr); + + const auto & subGraphPtr = m_Model->subgraphs[subgraphIndex]; + const auto & operatorPtr = subGraphPtr->operators[operatorIndex]; + + BOOST_ASSERT(operatorPtr->inputs.size() > 1); + + uint32_t reshapedInputId = CHECKED_NON_NEGATIVE(operatorPtr->inputs[0]); + TensorRawPtr tensorPtr = subGraphPtr->tensors[reshapedInputId].get(); + uint32_t inputId = CHECKED_NON_NEGATIVE(operatorPtr->inputs[1]); + TensorRawPtr tensorPtr1 = subGraphPtr->tensors[inputId].get(); + + armnn::TensorInfo reshapedTensorInfo = ToTensorInfo(tensorPtr); + armnn::TensorInfo inputTensorInfo = ToTensorInfo(tensorPtr1); + + if (inputTensorInfo.GetNumDimensions() < reshapedTensorInfo.GetNumDimensions()) + { + uint32_t id = reshapedInputId; + reshapedInputId = inputId; + inputId = id; + + reshapedTensorInfo = ToTensorInfo(tensorPtr1); + inputTensorInfo = ToTensorInfo(tensorPtr); + } + + uint32_t numDimensions = inputTensorInfo.GetNumDimensions(); + + std::vector reshapedDim; + for (unsigned int i = 0; i < reshapedTensorInfo.GetNumDimensions(); ++i) + { + reshapedDim.push_back(reshapedTensorInfo.GetShape()[i]); + } + + std::vector reshapedDimensions(numDimensions, 1); + std::copy_backward (reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end()); + + reshapedTensorInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() }); + + std::string layerName = boost::str(boost::format("Reshape_for:%1%") % layer->GetName()); + armnn::ReshapeDescriptor desc; + desc.m_TargetShape = reshapedTensorInfo.GetShape(); + armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, layerName.c_str()); + + reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo); + reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0)); + + RegisterInputSlots(subgraphIndex, operatorIndex, reshapeLayer, {reshapedInputId}); + + armnn::IInputSlot* input1Slot = &(layer->GetInputSlot(1)); + RegisterConsumerOfTensor(subgraphIndex, inputId, input1Slot); +} + INetworkPtr TfLiteParser::CreateNetworkFromBinaryFile(const char* graphFile) { ResetParser(); @@ -1008,6 +1065,9 @@ void TfLiteParser::ParseAdd(size_t subgraphIndex, size_t operatorIndex) auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(outputs.size(), 1); + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); + armnn::TensorInfo input1TensorInfo = ToTensorInfo(inputs[1]); + auto layerName = boost::str(boost::format("Add:%1%:%2%") % subgraphIndex % operatorIndex); IConnectableLayer* layer = m_Network->AddAdditionLayer(layerName.c_str()); @@ -1015,7 +1075,14 @@ void TfLiteParser::ParseAdd(size_t subgraphIndex, size_t operatorIndex) layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); - RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]}); + if (inputTensorInfo.GetNumDimensions() != input1TensorInfo.GetNumDimensions()) + { + AddBroadcastReshapeLayer(subgraphIndex, operatorIndex, layer); + } + else + { + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]}); + } layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function); @@ -1036,6 +1103,9 @@ void TfLiteParser::ParseMul(size_t subgraphIndex, size_t operatorIndex) auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(outputs.size(), 1); + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); + armnn::TensorInfo input1TensorInfo = ToTensorInfo(inputs[1]); + auto layerName = boost::str(boost::format("Mul:%1%:%2%") % subgraphIndex % operatorIndex); IConnectableLayer* layer = m_Network->AddMultiplicationLayer(layerName.c_str()); @@ -1043,7 +1113,14 @@ void TfLiteParser::ParseMul(size_t subgraphIndex, size_t operatorIndex) layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); - RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]}); + if (inputTensorInfo.GetNumDimensions() != input1TensorInfo.GetNumDimensions()) + { + AddBroadcastReshapeLayer(subgraphIndex, operatorIndex, layer); + } + else + { + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]}); + } layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function); diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp index 6c264372ba..34ae07f392 100644 --- a/src/armnnTfLiteParser/TfLiteParser.hpp +++ b/src/armnnTfLiteParser/TfLiteParser.hpp @@ -124,6 +124,10 @@ private: void ResetParser(); + void AddBroadcastReshapeLayer(size_t subgraphIndex, + size_t operatorIndex, + armnn::IConnectableLayer* layer); + /// Attach an activation layer to the one passed as a parameter armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer, unsigned int outputSlot, diff --git a/src/armnnTfLiteParser/test/Multiplication.cpp b/src/armnnTfLiteParser/test/Multiplication.cpp index f7e2edd546..dabf868559 100644 --- a/src/armnnTfLiteParser/test/Multiplication.cpp +++ b/src/armnnTfLiteParser/test/Multiplication.cpp @@ -108,4 +108,40 @@ BOOST_FIXTURE_TEST_CASE(ParseMultiplication, SimpleMultiplicationFixture) 45.0f, 50.0f, 55.0f } } }); } +struct MultiplicationBroadcastFixture4D1D : public MultiplicationFixture +{ + MultiplicationBroadcastFixture4D1D() : MultiplicationFixture("[ 1, 2, 2, 3 ]", "[ 1 ]", "[ 1, 2, 2, 3 ]") {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseMultiplicationBroadcast4D1D, MultiplicationBroadcastFixture4D1D) +{ + RunTest<4, float>(0, {{ "inputTensor1", { 0.0f, 1.0f, 2.0f, + 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f } }, + { "inputTensor2", { 5.0f } } }, + {{ "outputTensor", { 0.0f, 5.0f, 10.0f, + 15.0f, 20.0f, 25.0f, + 30.0f, 35.0f, 40.0f, + 45.0f, 50.0f, 55.0f } } }); +} + +struct MultiplicationBroadcastFixture1D4D : public MultiplicationFixture +{ + MultiplicationBroadcastFixture1D4D() : MultiplicationFixture("[ 1 ]", "[ 1, 2, 2, 3 ]", "[ 1, 2, 2, 3 ]") {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseMultiplicationBroadcast1D4D, MultiplicationBroadcastFixture1D4D) +{ + RunTest<4, float>(0, {{ "inputTensor1", { 3.0f } }, + { "inputTensor2", { 0.0f, 1.0f, 2.0f, + 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f } } }, + {{ "outputTensor", { 0.0f, 3.0f, 6.0f, + 9.0f, 12.0f, 15.0f, + 18.0f, 21.0f, 24.0f, + 27.0f, 30.0f, 33.0f } } }); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1