aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/ElementwiseBinary.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/ElementwiseBinary.hpp')
-rw-r--r--delegate/src/ElementwiseBinary.hpp64
1 files changed, 63 insertions, 1 deletions
diff --git a/delegate/src/ElementwiseBinary.hpp b/delegate/src/ElementwiseBinary.hpp
index 58d7aca0ee..0534c070be 100644
--- a/delegate/src/ElementwiseBinary.hpp
+++ b/delegate/src/ElementwiseBinary.hpp
@@ -6,6 +6,8 @@
#pragma once
#include "DelegateUtils.hpp"
+#include "MultiLayerFacade.hpp"
+#include "SharedFunctions.hpp"
#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
@@ -39,6 +41,7 @@ TfLiteStatus ValidateAddOperator(DelegateData& delegateData,
return isSupported ? kTfLiteOk : kTfLiteError;
}
+
TfLiteStatus ValidateDivOperator(DelegateData& delegateData,
TfLiteContext* tfLiteContext,
const armnn::TensorInfo& inputInfo1,
@@ -62,6 +65,35 @@ TfLiteStatus ValidateDivOperator(DelegateData& delegateData,
return isSupported ? kTfLiteOk : kTfLiteError;
}
+TfLiteStatus ValidateFloorDivOperator(DelegateData& delegateData,
+ TfLiteContext* tfLiteContext,
+ const armnn::TensorInfo& inputInfo1,
+ const armnn::TensorInfo& inputInfo2,
+ const armnn::TensorInfo& outputInfo)
+{
+ // need first to validate that the div operator is supported
+ // then that the floor operator is supported
+ TfLiteStatus status = ValidateDivOperator(delegateData, tfLiteContext, inputInfo1, inputInfo2, outputInfo);
+ if (status != kTfLiteOk)
+ {
+ return status;
+ }
+ // if the inputs and output of the div are all Signed32 we don't need to add the floor operator afterward.
+ if (AreAllSigned32(inputInfo1, inputInfo2, outputInfo))
+ {
+ return status;
+ }
+ // in case broadcasting is being done from one of the inputs to the div
+ // choose the full sized input tensor to pass to the floor validation routine
+ armnn::TensorInfo floorInputInfo = inputInfo1;
+ if (inputInfo1.GetNumDimensions() < inputInfo2.GetNumDimensions())
+ {
+ floorInputInfo = inputInfo2;
+ }
+ status = ValidateFloorOperator(delegateData, tfLiteContext, floorInputInfo, outputInfo);
+ return status;
+}
+
TfLiteStatus ValidateMaximumOperator(DelegateData& delegateData,
TfLiteContext* tfLiteContext,
const armnn::TensorInfo& inputInfo1,
@@ -154,6 +186,23 @@ TfLiteStatus ValidateSubOperator(DelegateData& delegateData,
return isSupported ? kTfLiteOk : kTfLiteError;
}
+std::pair<armnn::IConnectableLayer*, armnn::IConnectableLayer*> AddFloorDivLayer(
+ DelegateData& delegateData,
+ const armnn::TensorInfo& outputTensorInfo)
+{
+ armnn::IConnectableLayer* divisionLayer = delegateData.m_Network->AddDivisionLayer();
+ // if the output of the div is Signed32 the Floor layer is not required
+ if (armnn::DataType::Signed32 == outputTensorInfo.GetDataType())
+ {
+ return std::make_pair(divisionLayer, divisionLayer);
+ }
+ armnn::IOutputSlot& outputSlot = divisionLayer->GetOutputSlot(0);
+ outputSlot.SetTensorInfo(outputTensorInfo);
+ armnn::IConnectableLayer* floorLayer = delegateData.m_Network->AddFloorLayer();
+ outputSlot.Connect(floorLayer->GetInputSlot(0));
+ return std::make_pair(divisionLayer, floorLayer);
+}
+
TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
TfLiteContext* tfLiteContext,
TfLiteNode* tfLiteNode,
@@ -215,6 +264,12 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
inputTensorInfo0,
inputTensorInfo1,
outputTensorInfo);
+ case kTfLiteBuiltinFloorDiv:
+ return ValidateFloorDivOperator(delegateData,
+ tfLiteContext,
+ inputTensorInfo0,
+ inputTensorInfo1,
+ outputTensorInfo);
case kTfLiteBuiltinMaximum:
return ValidateMaximumOperator(delegateData,
tfLiteContext,
@@ -245,7 +300,7 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
}
armnn::IConnectableLayer* elementwiseBinaryLayer = nullptr;
-
+ MultiLayerFacade multiLayer;
switch(elementwiseBinaryOperatorCode)
{
case kTfLiteBuiltinAdd:
@@ -254,6 +309,13 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
case kTfLiteBuiltinDiv:
elementwiseBinaryLayer = delegateData.m_Network->AddDivisionLayer();
break;
+ case kTfLiteBuiltinFloorDiv:
+ {
+ auto layers = AddFloorDivLayer(delegateData, outputTensorInfo);
+ multiLayer.AssignValues(layers.first, layers.second);
+ elementwiseBinaryLayer = &multiLayer;
+ }
+ break;
case kTfLiteBuiltinMaximum:
elementwiseBinaryLayer = delegateData.m_Network->AddMaximumLayer();
break;