diff options
Diffstat (limited to 'src')
-rwxr-xr-x | src/armnnTfParser/TfParser.cpp | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index b5fe6be075..0410460059 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -2910,7 +2910,25 @@ ParsedTfOperationPtr TfParser::AddAdditionLayer(const tensorflow::NodeDef& nodeD input0Slot->Connect(layer->GetInputSlot(0)); input1Slot->Connect(layer->GetInputSlot(1)); - if (input0Info.GetNumDimensions() == 1 && isBiasAdd == false) + if (input0Info.GetNumDimensions() == input1Info.GetNumDimensions()) + { + const TensorShape& input0Shape = input0Info.GetShape(); + const TensorShape& input1Shape = input1Info.GetShape(); + + std::vector<unsigned int> outputShape; + outputShape.reserve(input0Shape.GetNumDimensions()); + TensorInfo outputInfo(input0Info); + + for (unsigned int i = 0; i < input0Shape.GetNumDimensions(); i++) + { + outputShape.push_back(std::max(input0Shape[i], input1Shape[i])); + } + + outputInfo.SetShape(TensorShape(input0Shape.GetNumDimensions(), outputShape.data())); + + layer->GetOutputSlot(0).SetTensorInfo(outputInfo); + } + else if (input0Info.GetNumDimensions() == 1 && isBiasAdd == false) { layer->GetOutputSlot(0).SetTensorInfo(input1Slot->GetTensorInfo()); } |