aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp20
1 files changed, 19 insertions, 1 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index b5fe6be075..0410460059 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -2910,7 +2910,25 @@ ParsedTfOperationPtr TfParser::AddAdditionLayer(const tensorflow::NodeDef& nodeD
input0Slot->Connect(layer->GetInputSlot(0));
input1Slot->Connect(layer->GetInputSlot(1));
- if (input0Info.GetNumDimensions() == 1 && isBiasAdd == false)
+ if (input0Info.GetNumDimensions() == input1Info.GetNumDimensions())
+ {
+ const TensorShape& input0Shape = input0Info.GetShape();
+ const TensorShape& input1Shape = input1Info.GetShape();
+
+ std::vector<unsigned int> outputShape;
+ outputShape.reserve(input0Shape.GetNumDimensions());
+ TensorInfo outputInfo(input0Info);
+
+ for (unsigned int i = 0; i < input0Shape.GetNumDimensions(); i++)
+ {
+ outputShape.push_back(std::max(input0Shape[i], input1Shape[i]));
+ }
+
+ outputInfo.SetShape(TensorShape(input0Shape.GetNumDimensions(), outputShape.data()));
+
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+ }
+ else if (input0Info.GetNumDimensions() == 1 && isBiasAdd == false)
{
layer->GetOutputSlot(0).SetTensorInfo(input1Slot->GetTensorInfo());
}