diff options
Diffstat (limited to 'delegate/src/ElementwiseBinary.hpp')
-rw-r--r-- | delegate/src/ElementwiseBinary.hpp | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/delegate/src/ElementwiseBinary.hpp b/delegate/src/ElementwiseBinary.hpp index 8096acfefb..52c6b2434b 100644 --- a/delegate/src/ElementwiseBinary.hpp +++ b/delegate/src/ElementwiseBinary.hpp @@ -254,6 +254,13 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData, const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true); + // Check if we need to expand the dims of the input tensor infos. + // This is required for a few of the backends. + if(inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions()) + { + ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1); + } + auto* tfLiteNodeParameters = reinterpret_cast<TfLiteAddParams*>(tfLiteNode->builtin_data); TfLiteFusedActivation activationType = kTfLiteActNone; if (tfLiteNodeParameters) @@ -363,13 +370,7 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData, return inputsTensorsProcess; } - auto reshapeLayer = BroadcastTensor(inputTensorInfo0, - inputTensorInfo1, - elementwiseBinaryLayer, - tfLiteContext, - tfLiteNode, - delegateData); - if (!reshapeLayer) + if(Connect(elementwiseBinaryLayer, tfLiteNode, delegateData) != kTfLiteOk) { return kTfLiteError; } |