aboutsummaryrefslogtreecommitdiff
path: root/delegate/opaque/src/Comparison.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/opaque/src/Comparison.hpp')
-rw-r--r--delegate/opaque/src/Comparison.hpp71
1 files changed, 36 insertions, 35 deletions
diff --git a/delegate/opaque/src/Comparison.hpp b/delegate/opaque/src/Comparison.hpp
index 046be83094..8740cfb0ea 100644
--- a/delegate/opaque/src/Comparison.hpp
+++ b/delegate/opaque/src/Comparison.hpp
@@ -7,19 +7,44 @@
#include <OpaqueDelegateUtils.hpp>
-#include <tensorflow/lite/builtin_ops.h>
-#include <tensorflow/lite/c/builtin_op_data.h>
-#include <tensorflow/lite/c/common.h>
-#include <tensorflow/lite/minimal_logging.h>
-
namespace armnnOpaqueDelegate
{
+std::string GetLayerName(armnn::ComparisonOperation comparisonOperation)
+{
+ std::string layerName = "COMPARISON";
+ switch (comparisonOperation)
+ {
+ case armnn::ComparisonOperation::NotEqual:
+ layerName += " NOT_EQUAL";
+ break;
+ case armnn::ComparisonOperation::Equal:
+ layerName += " EQUAL";
+ break;
+ case armnn::ComparisonOperation::Greater:
+ layerName += " GREATER";
+ break;
+ case armnn::ComparisonOperation::GreaterOrEqual:
+ layerName += " GREATER_OR_EQUAL";
+ break;
+ case armnn::ComparisonOperation::Less:
+ layerName += " LESS";
+ break;
+ case armnn::ComparisonOperation::LessOrEqual:
+ layerName += " LESS_OR_EQUAL";
+ break;
+ default:
+ layerName += " UNKNOWN";
+ }
+ return layerName;
+}
+
TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
TfLiteOpaqueContext* tfLiteContext,
TfLiteOpaqueNode* tfLiteNode,
int nodeIndex,
- int32_t tfLiteComparisonOperatorCode)
+ int32_t tfLiteComparisonOperatorCode,
+ armnn::ComparisonOperation comparisonOperation)
{
TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
@@ -61,6 +86,7 @@ TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
return kTfLiteError;
}
+ // Use output indices to get output tensor.
const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteComparisonOperatorCode, nodeIndex))
{
@@ -78,37 +104,12 @@ TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
}
- armnn::ComparisonOperation comparisonOperation = armnn::ComparisonOperation::Equal;
- switch(tfLiteComparisonOperatorCode)
- {
- case kTfLiteBuiltinEqual:
- comparisonOperation = armnn::ComparisonOperation::Equal;
- break;
- case kTfLiteBuiltinGreater:
- comparisonOperation = armnn::ComparisonOperation::Greater;
- break;
- case kTfLiteBuiltinGreaterEqual:
- comparisonOperation = armnn::ComparisonOperation::GreaterOrEqual;
- break;
- case kTfLiteBuiltinLess:
- comparisonOperation = armnn::ComparisonOperation::Less;
- break;
- case kTfLiteBuiltinLessEqual:
- comparisonOperation = armnn::ComparisonOperation::LessOrEqual;
- break;
- case kTfLiteBuiltinNotEqual:
- comparisonOperation = armnn::ComparisonOperation::NotEqual;
- break;
- default:
- return kTfLiteError;
- }
-
armnn::ComparisonDescriptor descriptor(comparisonOperation);
bool isSupported = false;
armnn::BackendId setBackend;
- auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
+ auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported, std::string layerName)
{
- FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("COMPARISON",
+ FORWARD_LAYER_OPAQUE_SUPPORT_FUNC(layerName.c_str(),
tfLiteContext,
IsComparisonSupported,
delegateData.m_Backends,
@@ -122,7 +123,7 @@ TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
if (!delegateData.m_Network)
{
- validateFunc(outputTensorInfo, isSupported);
+ validateFunc(outputTensorInfo, isSupported, GetLayerName(comparisonOperation));
return isSupported ? kTfLiteOk : kTfLiteError;
}
@@ -142,4 +143,4 @@ TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
return Connect(comparisonLayer, tfLiteContext, tfLiteNode, delegateData);
}
-} // namespace armnnDelegate
+} // namespace armnnOpaqueDelegate