aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/ElementwiseBinary.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/ElementwiseBinary.hpp')
-rw-r--r--delegate/src/ElementwiseBinary.hpp15
1 files changed, 8 insertions, 7 deletions
diff --git a/delegate/src/ElementwiseBinary.hpp b/delegate/src/ElementwiseBinary.hpp
index 8096acfefb..52c6b2434b 100644
--- a/delegate/src/ElementwiseBinary.hpp
+++ b/delegate/src/ElementwiseBinary.hpp
@@ -254,6 +254,13 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
+ // Check if we need to expand the dims of the input tensor infos.
+ // This is required for a few of the backends.
+ if(inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions())
+ {
+ ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
+ }
+
auto* tfLiteNodeParameters = reinterpret_cast<TfLiteAddParams*>(tfLiteNode->builtin_data);
TfLiteFusedActivation activationType = kTfLiteActNone;
if (tfLiteNodeParameters)
@@ -363,13 +370,7 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
return inputsTensorsProcess;
}
- auto reshapeLayer = BroadcastTensor(inputTensorInfo0,
- inputTensorInfo1,
- elementwiseBinaryLayer,
- tfLiteContext,
- tfLiteNode,
- delegateData);
- if (!reshapeLayer)
+ if(Connect(elementwiseBinaryLayer, tfLiteNode, delegateData) != kTfLiteOk)
{
return kTfLiteError;
}