diff options
Diffstat (limited to 'delegate/src/ElementwiseBinary.hpp')
-rw-r--r-- | delegate/src/ElementwiseBinary.hpp | 161 |
1 files changed, 155 insertions, 6 deletions
diff --git a/delegate/src/ElementwiseBinary.hpp b/delegate/src/ElementwiseBinary.hpp index a22d9f5751..3d3f1a0799 100644 --- a/delegate/src/ElementwiseBinary.hpp +++ b/delegate/src/ElementwiseBinary.hpp @@ -38,15 +38,119 @@ TfLiteStatus ValidateAddOperator(DelegateData& delegateData, return isSupported ? kTfLiteOk : kTfLiteError; } -armnn::IConnectableLayer* AddAdditionLayer(DelegateData& delegateData) +TfLiteStatus ValidateDivOperator(DelegateData& delegateData, + TfLiteContext* tfLiteContext, + const armnn::TensorInfo& inputInfo1, + const armnn::TensorInfo& inputInfo2, + const armnn::TensorInfo& outputInfo) { + bool isSupported = false; + auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported) + { + FORWARD_LAYER_SUPPORT_FUNC(__func__, + tfLiteContext, + IsDivisionSupported, + delegateData.m_Backends, + isSupported, + inputInfo1, + inputInfo2, + outputTensorInfo); + }; - if (!delegateData.m_Network) + validateFunc(outputInfo, isSupported); + return isSupported ? kTfLiteOk : kTfLiteError; +} + +TfLiteStatus ValidateMaximumOperator(DelegateData& delegateData, + TfLiteContext* tfLiteContext, + const armnn::TensorInfo& inputInfo1, + const armnn::TensorInfo& inputInfo2, + const armnn::TensorInfo& outputInfo) +{ + bool isSupported = false; + auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported) { - return nullptr; - } + FORWARD_LAYER_SUPPORT_FUNC(__func__, + tfLiteContext, + IsMaximumSupported, + delegateData.m_Backends, + isSupported, + inputInfo1, + inputInfo2, + outputTensorInfo); + }; + + validateFunc(outputInfo, isSupported); + return isSupported ? kTfLiteOk : kTfLiteError; +} + +TfLiteStatus ValidateMinimumOperator(DelegateData& delegateData, + TfLiteContext* tfLiteContext, + const armnn::TensorInfo& inputInfo1, + const armnn::TensorInfo& inputInfo2, + const armnn::TensorInfo& outputInfo) +{ + bool isSupported = false; + auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported) + { + FORWARD_LAYER_SUPPORT_FUNC(__func__, + tfLiteContext, + IsMinimumSupported, + delegateData.m_Backends, + isSupported, + inputInfo1, + inputInfo2, + outputTensorInfo); + }; + + validateFunc(outputInfo, isSupported); + return isSupported ? kTfLiteOk : kTfLiteError; +} + +TfLiteStatus ValidateMulOperator(DelegateData& delegateData, + TfLiteContext* tfLiteContext, + const armnn::TensorInfo& inputInfo1, + const armnn::TensorInfo& inputInfo2, + const armnn::TensorInfo& outputInfo) +{ + bool isSupported = false; + auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported) + { + FORWARD_LAYER_SUPPORT_FUNC(__func__, + tfLiteContext, + IsMultiplicationSupported, + delegateData.m_Backends, + isSupported, + inputInfo1, + inputInfo2, + outputTensorInfo); + }; - return delegateData.m_Network->AddAdditionLayer(); + validateFunc(outputInfo, isSupported); + return isSupported ? kTfLiteOk : kTfLiteError; +} + +TfLiteStatus ValidateSubOperator(DelegateData& delegateData, + TfLiteContext* tfLiteContext, + const armnn::TensorInfo& inputInfo1, + const armnn::TensorInfo& inputInfo2, + const armnn::TensorInfo& outputInfo) +{ + bool isSupported = false; + auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported) + { + FORWARD_LAYER_SUPPORT_FUNC(__func__, + tfLiteContext, + IsSubtractionSupported, + delegateData.m_Backends, + isSupported, + inputInfo1, + inputInfo2, + outputTensorInfo); + }; + + validateFunc(outputInfo, isSupported); + return isSupported ? kTfLiteOk : kTfLiteError; } TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData, @@ -103,6 +207,36 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData, inputTensorInfo0, inputTensorInfo1, outputTensorInfo); + case kTfLiteBuiltinDiv: + return ValidateDivOperator(delegateData, + tfLiteContext, + inputTensorInfo0, + inputTensorInfo1, + outputTensorInfo); + case kTfLiteBuiltinMaximum: + return ValidateMaximumOperator(delegateData, + tfLiteContext, + inputTensorInfo0, + inputTensorInfo1, + outputTensorInfo); + case kTfLiteBuiltinMinimum: + return ValidateMinimumOperator(delegateData, + tfLiteContext, + inputTensorInfo0, + inputTensorInfo1, + outputTensorInfo); + case kTfLiteBuiltinMul: + return ValidateDivOperator(delegateData, + tfLiteContext, + inputTensorInfo0, + inputTensorInfo1, + outputTensorInfo); + case kTfLiteBuiltinSub: + return ValidateDivOperator(delegateData, + tfLiteContext, + inputTensorInfo0, + inputTensorInfo1, + outputTensorInfo); default: return kTfLiteError; } @@ -113,7 +247,22 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData, switch(elementwiseBinaryOperatorCode) { case kTfLiteBuiltinAdd: - elementwiseBinaryLayer = AddAdditionLayer(delegateData); + elementwiseBinaryLayer = delegateData.m_Network->AddAdditionLayer(); + break; + case kTfLiteBuiltinDiv: + elementwiseBinaryLayer = delegateData.m_Network->AddDivisionLayer(); + break; + case kTfLiteBuiltinMaximum: + elementwiseBinaryLayer = delegateData.m_Network->AddMaximumLayer(); + break; + case kTfLiteBuiltinMinimum: + elementwiseBinaryLayer = delegateData.m_Network->AddMinimumLayer(); + break; + case kTfLiteBuiltinMul: + elementwiseBinaryLayer = delegateData.m_Network->AddMultiplicationLayer(); + break; + case kTfLiteBuiltinSub: + elementwiseBinaryLayer = delegateData.m_Network->AddSubtractionLayer(); break; default: return kTfLiteError; |