aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/Comparison.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/Comparison.hpp')
-rw-r--r--delegate/src/Comparison.hpp23
1 files changed, 10 insertions, 13 deletions
diff --git a/delegate/src/Comparison.hpp b/delegate/src/Comparison.hpp
index 80354e835d..688f90c597 100644
--- a/delegate/src/Comparison.hpp
+++ b/delegate/src/Comparison.hpp
@@ -57,10 +57,17 @@ TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
return kTfLiteError;
}
- const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
- const armnn::TensorInfo& inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
+ armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
+ armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
+ // Check if we need to expand the dims of any of the input tensor infos.
+ // This is required for a few of the backends.
+ if(inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions())
+ {
+ ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
+ }
+
armnn::ComparisonOperation comparisonOperation = armnn::ComparisonOperation::Equal;
switch(tfLiteComparisonOperatorCode)
{
@@ -122,17 +129,7 @@ TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
return kTfLiteError;
}
- auto reshapeLayer = BroadcastTensor(inputTensorInfo0,
- inputTensorInfo1,
- comparisonLayer,
- tfLiteContext,
- tfLiteNode,
- delegateData);
- if (!reshapeLayer)
- {
- return kTfLiteError;
- }
- return kTfLiteOk;
+ return Connect(comparisonLayer, tfLiteNode, delegateData);
}
} // namespace armnnDelegate