diff options
Diffstat (limited to 'delegate/src/Comparison.hpp')
-rw-r--r-- | delegate/src/Comparison.hpp | 23 |
1 files changed, 10 insertions, 13 deletions
diff --git a/delegate/src/Comparison.hpp b/delegate/src/Comparison.hpp index 80354e835d..688f90c597 100644 --- a/delegate/src/Comparison.hpp +++ b/delegate/src/Comparison.hpp @@ -57,10 +57,17 @@ TfLiteStatus VisitComparisonOperator(DelegateData& delegateData, return kTfLiteError; } - const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0); - const armnn::TensorInfo& inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1); + armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0); + armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1); const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true); + // Check if we need to expand the dims of any of the input tensor infos. + // This is required for a few of the backends. + if(inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions()) + { + ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1); + } + armnn::ComparisonOperation comparisonOperation = armnn::ComparisonOperation::Equal; switch(tfLiteComparisonOperatorCode) { @@ -122,17 +129,7 @@ TfLiteStatus VisitComparisonOperator(DelegateData& delegateData, return kTfLiteError; } - auto reshapeLayer = BroadcastTensor(inputTensorInfo0, - inputTensorInfo1, - comparisonLayer, - tfLiteContext, - tfLiteNode, - delegateData); - if (!reshapeLayer) - { - return kTfLiteError; - } - return kTfLiteOk; + return Connect(comparisonLayer, tfLiteNode, delegateData); } } // namespace armnnDelegate |