aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
authorBruno Goncalves <bruno.slackware@gmail.com>2018-12-27 14:20:35 -0200
committerMatthew Bentham <matthew.bentham@arm.com>2019-01-22 17:10:06 +0000
commit9c761a6335f58901bef01b33f1127f9fae3b2bf3 (patch)
tree114970cbde18a4facdc655a2bb917fd06390f168 /src/armnnTfLiteParser/TfLiteParser.cpp
parent33f8e3b6c71070fd867809ca6934069a950081dc (diff)
downloadarmnn-9c761a6335f58901bef01b33f1127f9fae3b2bf3.tar.gz
Added AddBroadcastReshapeLayer method to TfLiteParser
Change-Id: I6027f6dcdb3ed23505f0a9c780bd3e3d45d3daff
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp81
1 files changed, 79 insertions, 2 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 0e8d3c5b68..c45e794274 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -463,6 +463,63 @@ void TfLiteParser::ResetParser()
m_SubgraphConnections.clear();
}
+// Inserts a Reshape layer in front of one input of 'layer' (a two-input
+// elementwise op) so that both inputs have the same rank: the lower-rank
+// tensor's shape is right-aligned and padded with leading 1s up to the
+// higher rank, enabling broadcasting. Called by ParseAdd/ParseMul when the
+// two operand ranks differ.
+void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex,
+ size_t operatorIndex,
+ IConnectableLayer *layer)
+{
+ CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+ BOOST_ASSERT(layer != nullptr);
+
+ const auto & subGraphPtr = m_Model->subgraphs[subgraphIndex];
+ const auto & operatorPtr = subGraphPtr->operators[operatorIndex];
+
+ BOOST_ASSERT(operatorPtr->inputs.size() > 1);
+
+ // Start by assuming input 0 is the tensor to be reshaped and input 1 is
+ // the reference (higher-rank) tensor; swapped below if that guess is wrong.
+ uint32_t reshapedInputId = CHECKED_NON_NEGATIVE(operatorPtr->inputs[0]);
+ TensorRawPtr tensorPtr = subGraphPtr->tensors[reshapedInputId].get();
+ uint32_t inputId = CHECKED_NON_NEGATIVE(operatorPtr->inputs[1]);
+ TensorRawPtr tensorPtr1 = subGraphPtr->tensors[inputId].get();
+
+ armnn::TensorInfo reshapedTensorInfo = ToTensorInfo(tensorPtr);
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(tensorPtr1);
+
+ // If input 1 is actually the lower-rank tensor, swap the roles so that
+ // 'reshaped*' always refers to the tensor that needs the rank expansion.
+ if (inputTensorInfo.GetNumDimensions() < reshapedTensorInfo.GetNumDimensions())
+ {
+ uint32_t id = reshapedInputId;
+ reshapedInputId = inputId;
+ inputId = id;
+
+ reshapedTensorInfo = ToTensorInfo(tensorPtr1);
+ inputTensorInfo = ToTensorInfo(tensorPtr);
+ }
+
+ // Target rank: that of the higher-rank input.
+ uint32_t numDimensions = inputTensorInfo.GetNumDimensions();
+
+ std::vector<unsigned> reshapedDim;
+ for (unsigned int i = 0; i < reshapedTensorInfo.GetNumDimensions(); ++i)
+ {
+ reshapedDim.push_back(reshapedTensorInfo.GetShape()[i]);
+ }
+
+ // Build the new shape: start as all 1s at the target rank, then copy the
+ // original dims right-aligned into the tail (e.g. [3,4] -> [1,1,3,4]).
+ std::vector<unsigned int> reshapedDimensions(numDimensions, 1);
+ std::copy_backward (reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());
+
+ reshapedTensorInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() });
+
+ std::string layerName = boost::str(boost::format("Reshape_for:%1%") % layer->GetName());
+ armnn::ReshapeDescriptor desc;
+ desc.m_TargetShape = reshapedTensorInfo.GetShape();
+ armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, layerName.c_str());
+
+ // Wire: reshaped tensor -> reshapeLayer -> layer slot 0; the other tensor
+ // feeds layer slot 1 directly.
+ // NOTE(review): when the swap above fires, the reshaped operand ends up on
+ // slot 0 even though it was originally input 1 — this reorders the operands,
+ // which is only safe for commutative ops (Add/Mul). Confirm before reusing
+ // for Sub/Div.
+ reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
+ reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
+
+ RegisterInputSlots(subgraphIndex, operatorIndex, reshapeLayer, {reshapedInputId});
+
+ armnn::IInputSlot* input1Slot = &(layer->GetInputSlot(1));
+ RegisterConsumerOfTensor(subgraphIndex, inputId, input1Slot);
+}
+
INetworkPtr TfLiteParser::CreateNetworkFromBinaryFile(const char* graphFile)
{
ResetParser();
@@ -1008,6 +1065,9 @@ void TfLiteParser::ParseAdd(size_t subgraphIndex, size_t operatorIndex)
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+ armnn::TensorInfo input1TensorInfo = ToTensorInfo(inputs[1]);
+
auto layerName = boost::str(boost::format("Add:%1%:%2%") % subgraphIndex % operatorIndex);
IConnectableLayer* layer = m_Network->AddAdditionLayer(layerName.c_str());
@@ -1015,7 +1075,14 @@ void TfLiteParser::ParseAdd(size_t subgraphIndex, size_t operatorIndex)
layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
- RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
+ if (inputTensorInfo.GetNumDimensions() != input1TensorInfo.GetNumDimensions())
+ {
+ AddBroadcastReshapeLayer(subgraphIndex, operatorIndex, layer);
+ }
+ else
+ {
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
+ }
layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
@@ -1036,6 +1103,9 @@ void TfLiteParser::ParseMul(size_t subgraphIndex, size_t operatorIndex)
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+ armnn::TensorInfo input1TensorInfo = ToTensorInfo(inputs[1]);
+
auto layerName = boost::str(boost::format("Mul:%1%:%2%") % subgraphIndex % operatorIndex);
IConnectableLayer* layer = m_Network->AddMultiplicationLayer(layerName.c_str());
@@ -1043,7 +1113,14 @@ void TfLiteParser::ParseMul(size_t subgraphIndex, size_t operatorIndex)
layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
- RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
+ if (inputTensorInfo.GetNumDimensions() != input1TensorInfo.GetNumDimensions())
+ {
+ AddBroadcastReshapeLayer(subgraphIndex, operatorIndex, layer);
+ }
+ else
+ {
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
+ }
layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);