aboutsummaryrefslogtreecommitdiff
path: root/delegate
diff options
context:
space:
mode:
Diffstat (limited to 'delegate')
-rw-r--r--delegate/src/ElementwiseBinary.hpp52
1 files changed, 33 insertions, 19 deletions
diff --git a/delegate/src/ElementwiseBinary.hpp b/delegate/src/ElementwiseBinary.hpp
index 52c6b2434b..fa9021b5c1 100644
--- a/delegate/src/ElementwiseBinary.hpp
+++ b/delegate/src/ElementwiseBinary.hpp
@@ -27,15 +27,17 @@ TfLiteStatus ValidateAddOperator(DelegateData& delegateData,
bool isSupported = false;
auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
{
+ std::vector<armnn::TensorInfo> infos { inputInfo1, inputInfo2, outputInfo };
FORWARD_LAYER_SUPPORT_FUNC("ADD",
tfLiteContext,
- IsAdditionSupported,
+ IsElementwiseBinarySupported,
delegateData.m_Backends,
isSupported,
armnn::BackendId(),
inputInfo1,
inputInfo2,
- outputTensorInfo);
+ outputInfo,
+ armnn::BinaryOperation::Add);
};
validateFunc(outputInfo, isSupported);
@@ -54,13 +56,14 @@ TfLiteStatus ValidateDivOperator(DelegateData& delegateData,
{
FORWARD_LAYER_SUPPORT_FUNC("DIV",
tfLiteContext,
- IsDivisionSupported,
+ IsElementwiseBinarySupported,
delegateData.m_Backends,
isSupported,
armnn::BackendId(),
inputInfo1,
inputInfo2,
- outputTensorInfo);
+ outputTensorInfo,
+ armnn::BinaryOperation::Div);
};
validateFunc(outputInfo, isSupported);
@@ -107,13 +110,14 @@ TfLiteStatus ValidateMaximumOperator(DelegateData& delegateData,
{
FORWARD_LAYER_SUPPORT_FUNC("MAXIMUM",
tfLiteContext,
- IsMaximumSupported,
+ IsElementwiseBinarySupported,
delegateData.m_Backends,
isSupported,
armnn::BackendId(),
inputInfo1,
inputInfo2,
- outputTensorInfo);
+ outputTensorInfo,
+ armnn::BinaryOperation::Maximum);
};
validateFunc(outputInfo, isSupported);
@@ -131,13 +135,14 @@ TfLiteStatus ValidateMinimumOperator(DelegateData& delegateData,
{
FORWARD_LAYER_SUPPORT_FUNC("MINIMUM",
tfLiteContext,
- IsMinimumSupported,
+ IsElementwiseBinarySupported,
delegateData.m_Backends,
isSupported,
armnn::BackendId(),
inputInfo1,
inputInfo2,
- outputTensorInfo);
+ outputTensorInfo,
+ armnn::BinaryOperation::Minimum);
};
validateFunc(outputInfo, isSupported);
@@ -155,13 +160,14 @@ TfLiteStatus ValidateMulOperator(DelegateData& delegateData,
{
FORWARD_LAYER_SUPPORT_FUNC("MUL",
tfLiteContext,
- IsMultiplicationSupported,
+ IsElementwiseBinarySupported,
delegateData.m_Backends,
isSupported,
armnn::BackendId(),
inputInfo1,
inputInfo2,
- outputTensorInfo);
+ outputTensorInfo,
+ armnn::BinaryOperation::Mul);
};
validateFunc(outputInfo, isSupported);
@@ -179,13 +185,14 @@ TfLiteStatus ValidateSubOperator(DelegateData& delegateData,
{
FORWARD_LAYER_SUPPORT_FUNC("SUB",
tfLiteContext,
- IsSubtractionSupported,
+ IsElementwiseBinarySupported,
delegateData.m_Backends,
isSupported,
armnn::BackendId(),
inputInfo1,
inputInfo2,
- outputTensorInfo);
+ outputTensorInfo,
+ armnn::BinaryOperation::Sub);
};
validateFunc(outputInfo, isSupported);
@@ -196,7 +203,8 @@ std::pair<armnn::IConnectableLayer*, armnn::IConnectableLayer*> AddFloorDivLayer
DelegateData& delegateData,
const armnn::TensorInfo& outputTensorInfo)
{
- armnn::IConnectableLayer* divisionLayer = delegateData.m_Network->AddDivisionLayer();
+ armnn::IConnectableLayer* divisionLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
+ armnn::BinaryOperation::Div);
// if the output of the div is Signed32 the Floor layer is not required
if (armnn::DataType::Signed32 == outputTensorInfo.GetDataType())
{
@@ -330,10 +338,12 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
switch(elementwiseBinaryOperatorCode)
{
case kTfLiteBuiltinAdd:
- elementwiseBinaryLayer = delegateData.m_Network->AddAdditionLayer();
+ elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
+ armnn::BinaryOperation::Add);
break;
case kTfLiteBuiltinDiv:
- elementwiseBinaryLayer = delegateData.m_Network->AddDivisionLayer();
+ elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
+ armnn::BinaryOperation::Div);
break;
case kTfLiteBuiltinFloorDiv:
{
@@ -343,16 +353,20 @@ TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
}
break;
case kTfLiteBuiltinMaximum:
- elementwiseBinaryLayer = delegateData.m_Network->AddMaximumLayer();
+ elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
+ armnn::BinaryOperation::Maximum);
break;
case kTfLiteBuiltinMinimum:
- elementwiseBinaryLayer = delegateData.m_Network->AddMinimumLayer();
+ elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
+ armnn::BinaryOperation::Minimum);
break;
case kTfLiteBuiltinMul:
- elementwiseBinaryLayer = delegateData.m_Network->AddMultiplicationLayer();
+ elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
+ armnn::BinaryOperation::Mul);
break;
case kTfLiteBuiltinSub:
- elementwiseBinaryLayer = delegateData.m_Network->AddSubtractionLayer();
+ elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
+ armnn::BinaryOperation::Sub);
break;
default:
return kTfLiteError;