aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/ElementwiseBinary.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/ElementwiseBinary.hpp')
-rw-r--r--delegate/src/ElementwiseBinary.hpp37
1 files changed, 33 insertions, 4 deletions
diff --git a/delegate/src/ElementwiseBinary.hpp b/delegate/src/ElementwiseBinary.hpp
index 3d3f1a0799..e5270057f5 100644
--- a/delegate/src/ElementwiseBinary.hpp
+++ b/delegate/src/ElementwiseBinary.hpp
@@ -11,6 +11,7 @@
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>
+#include "tensorflow/lite/delegates/utils.h"
namespace armnnDelegate
{
@@ -193,8 +194,9 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
return kTfLiteError;
}
- const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
- const armnn::TensorInfo& inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
+ armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
+ armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
+
const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
if (!delegateData.m_Network)
@@ -268,10 +270,37 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
return kTfLiteError;
}
ARMNN_ASSERT(elementwiseBinaryLayer != nullptr);
-
armnn::IOutputSlot& outputSlot = elementwiseBinaryLayer->GetOutputSlot(0);
outputSlot.SetTensorInfo(outputTensorInfo);
+ if(tflite::IsConstantTensor(&tfLiteInputTensor0))
+ {
+ auto status = ConnectConstant(elementwiseBinaryLayer,
+ inputTensorInfo0,
+ tfLiteContext,
+ tfLiteInputTensor0,
+ delegateData,
+ tfLiteNode->inputs->data[0]);
+ if (status == kTfLiteError)
+ {
+ return status;
+ }
+ }
+
+ if(tflite::IsConstantTensor(&tfLiteInputTensor1))
+ {
+ auto status = ConnectConstant(elementwiseBinaryLayer,
+ inputTensorInfo1,
+ tfLiteContext,
+ tfLiteInputTensor1,
+ delegateData,
+ tfLiteNode->inputs->data[1]);
+ if (status == kTfLiteError)
+ {
+ return status;
+ }
+ }
+
auto reshapeLayer = BroadcastTensor(inputTensorInfo0,
inputTensorInfo1,
elementwiseBinaryLayer,
@@ -291,7 +320,7 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
}
// Check activation
TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
- return FusedActivation(tfLiteContext, tfLiteNode, activationType, reshapeLayer, 0, delegateData);
+ return FusedActivation(tfLiteContext, tfLiteNode, activationType, elementwiseBinaryLayer, 0, delegateData);
}
} // namespace armnnDelegate