diff options
Diffstat (limited to 'delegate/src/ElementwiseBinary.hpp')
-rw-r--r-- | delegate/src/ElementwiseBinary.hpp | 64 |
1 files changed, 63 insertions, 1 deletions
diff --git a/delegate/src/ElementwiseBinary.hpp b/delegate/src/ElementwiseBinary.hpp index 58d7aca0ee..0534c070be 100644 --- a/delegate/src/ElementwiseBinary.hpp +++ b/delegate/src/ElementwiseBinary.hpp @@ -6,6 +6,8 @@ #pragma once #include "DelegateUtils.hpp" +#include "MultiLayerFacade.hpp" +#include "SharedFunctions.hpp" #include <tensorflow/lite/builtin_ops.h> #include <tensorflow/lite/c/builtin_op_data.h> @@ -39,6 +41,7 @@ TfLiteStatus ValidateAddOperator(DelegateData& delegateData, return isSupported ? kTfLiteOk : kTfLiteError; } + TfLiteStatus ValidateDivOperator(DelegateData& delegateData, TfLiteContext* tfLiteContext, const armnn::TensorInfo& inputInfo1, @@ -62,6 +65,35 @@ TfLiteStatus ValidateDivOperator(DelegateData& delegateData, return isSupported ? kTfLiteOk : kTfLiteError; } +TfLiteStatus ValidateFloorDivOperator(DelegateData& delegateData, + TfLiteContext* tfLiteContext, + const armnn::TensorInfo& inputInfo1, + const armnn::TensorInfo& inputInfo2, + const armnn::TensorInfo& outputInfo) +{ + // need first to validate that the div operator is supported + // then that the floor operator is supported + TfLiteStatus status = ValidateDivOperator(delegateData, tfLiteContext, inputInfo1, inputInfo2, outputInfo); + if (status != kTfLiteOk) + { + return status; + } + // if the inputs and output of the div are all Signed32 we don't need to add the floor operator afterward. + if (AreAllSigned32(inputInfo1, inputInfo2, outputInfo)) + { + return status; + } + // in case broadcasting is being done from one of the inputs to the div + // choose the full sized input tensor to pass to the floor validation routine + armnn::TensorInfo floorInputInfo = inputInfo1; + if (inputInfo1.GetNumDimensions() < inputInfo2.GetNumDimensions()) + { + floorInputInfo = inputInfo2; + } + status = ValidateFloorOperator(delegateData, tfLiteContext, floorInputInfo, outputInfo); + return status; +} + TfLiteStatus ValidateMaximumOperator(DelegateData& delegateData, TfLiteContext* tfLiteContext, const armnn::TensorInfo& inputInfo1, @@ -154,6 +186,23 @@ TfLiteStatus ValidateSubOperator(DelegateData& delegateData, return isSupported ? kTfLiteOk : kTfLiteError; } +std::pair<armnn::IConnectableLayer*, armnn::IConnectableLayer*> AddFloorDivLayer( + DelegateData& delegateData, + const armnn::TensorInfo& outputTensorInfo) +{ + armnn::IConnectableLayer* divisionLayer = delegateData.m_Network->AddDivisionLayer(); + // if the output of the div is Signed32 the Floor layer is not required + if (armnn::DataType::Signed32 == outputTensorInfo.GetDataType()) + { + return std::make_pair(divisionLayer, divisionLayer); + } + armnn::IOutputSlot& outputSlot = divisionLayer->GetOutputSlot(0); + outputSlot.SetTensorInfo(outputTensorInfo); + armnn::IConnectableLayer* floorLayer = delegateData.m_Network->AddFloorLayer(); + outputSlot.Connect(floorLayer->GetInputSlot(0)); + return std::make_pair(divisionLayer, floorLayer); +} + TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData, TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode, @@ -215,6 +264,12 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData, inputTensorInfo0, inputTensorInfo1, outputTensorInfo); + case kTfLiteBuiltinFloorDiv: + return ValidateFloorDivOperator(delegateData, + tfLiteContext, + inputTensorInfo0, + inputTensorInfo1, + outputTensorInfo); case kTfLiteBuiltinMaximum: return ValidateMaximumOperator(delegateData, tfLiteContext, @@ -245,7 +300,7 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData, } armnn::IConnectableLayer* elementwiseBinaryLayer = nullptr; - + MultiLayerFacade multiLayer; switch(elementwiseBinaryOperatorCode) { case kTfLiteBuiltinAdd: @@ -254,6 +309,13 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData, case kTfLiteBuiltinDiv: elementwiseBinaryLayer = delegateData.m_Network->AddDivisionLayer(); break; + case kTfLiteBuiltinFloorDiv: + { + auto layers = AddFloorDivLayer(delegateData, outputTensorInfo); + multiLayer.AssignValues(layers.first, layers.second); + elementwiseBinaryLayer = &multiLayer; + } + break; case kTfLiteBuiltinMaximum: elementwiseBinaryLayer = delegateData.m_Network->AddMaximumLayer(); break; |