diff options
Diffstat (limited to 'delegate/src/ElementwiseBinary.hpp')
-rw-r--r-- | delegate/src/ElementwiseBinary.hpp | 37 |
1 files changed, 33 insertions, 4 deletions
diff --git a/delegate/src/ElementwiseBinary.hpp b/delegate/src/ElementwiseBinary.hpp index 3d3f1a0799..e5270057f5 100644 --- a/delegate/src/ElementwiseBinary.hpp +++ b/delegate/src/ElementwiseBinary.hpp @@ -11,6 +11,7 @@ #include <tensorflow/lite/c/builtin_op_data.h> #include <tensorflow/lite/c/common.h> #include <tensorflow/lite/minimal_logging.h> +#include "tensorflow/lite/delegates/utils.h" namespace armnnDelegate { @@ -193,8 +194,9 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData, return kTfLiteError; } - const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0); - const armnn::TensorInfo& inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1); + armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0); + armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1); + const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor); if (!delegateData.m_Network) @@ -268,10 +270,37 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData, return kTfLiteError; } ARMNN_ASSERT(elementwiseBinaryLayer != nullptr); - armnn::IOutputSlot& outputSlot = elementwiseBinaryLayer->GetOutputSlot(0); outputSlot.SetTensorInfo(outputTensorInfo); + if(tflite::IsConstantTensor(&tfLiteInputTensor0)) + { + auto status = ConnectConstant(elementwiseBinaryLayer, + inputTensorInfo0, + tfLiteContext, + tfLiteInputTensor0, + delegateData, + tfLiteNode->inputs->data[0]); + if (status == kTfLiteError) + { + return status; + } + } + + if(tflite::IsConstantTensor(&tfLiteInputTensor1)) + { + auto status = ConnectConstant(elementwiseBinaryLayer, + inputTensorInfo1, + tfLiteContext, + tfLiteInputTensor1, + delegateData, + tfLiteNode->inputs->data[1]); + if (status == kTfLiteError) + { + return status; + } + } + auto reshapeLayer = BroadcastTensor(inputTensorInfo0, inputTensorInfo1, elementwiseBinaryLayer, @@ -291,7 +320,7 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData, } // Check activation TfLiteFusedActivation activationType = tfLiteNodeParameters->activation; - return FusedActivation(tfLiteContext, tfLiteNode, activationType, reshapeLayer, 0, delegateData); + return FusedActivation(tfLiteContext, tfLiteNode, activationType, elementwiseBinaryLayer, 0, delegateData); } } // namespace armnnDelegate |