diff options
Diffstat (limited to 'delegate/opaque/src/armnn_delegate.cpp')
-rw-r--r-- | delegate/opaque/src/armnn_delegate.cpp | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/delegate/opaque/src/armnn_delegate.cpp b/delegate/opaque/src/armnn_delegate.cpp index 3b647f3531..ead577f806 100644 --- a/delegate/opaque/src/armnn_delegate.cpp +++ b/delegate/opaque/src/armnn_delegate.cpp @@ -641,6 +641,12 @@ TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData, tfLiteNode, nodeIndex, kTfLiteBuiltinArgMin); + case kTfLiteBuiltinAveragePool2d: + return VisitPooling2dOperator(delegateData, + tfLiteContext, + tfLiteNode, + nodeIndex, + kTfLiteBuiltinAveragePool2d); case kTfLiteBuiltinBatchMatmul: return VisitBatchMatMulOperator(delegateData, tfLiteContext, @@ -684,6 +690,30 @@ TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData, tfLiteNode, nodeIndex, kTfLiteBuiltinConv3d); + case kTfLiteBuiltinCustom: + { + // Custom operators are defined by the name rather than the builtin code. + // Parse the custom_name param in the registration to point to the correct visitor function. + std::string customOperatorName = TfLiteRegistrationExternalGetCustomName(tfLiteRegistration); + if ( customOperatorName == "AveragePool3D" ) + { + return VisitPooling3dOperator(delegateData, + tfLiteContext, + tfLiteNode, + nodeIndex, + customOperatorName); + } + else if (customOperatorName == "MaxPool3D") + { + return VisitPooling3dOperator(delegateData, + tfLiteContext, + tfLiteNode, + nodeIndex, + customOperatorName); + } + // Invalid or unsupported custom operator + return kTfLiteError; + } case kTfLiteBuiltinDepthwiseConv2d: return VisitConvolutionOperator(delegateData, tfLiteContext, @@ -710,6 +740,12 @@ TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData, nodeIndex, kTfLiteBuiltinExp, armnn::UnaryOperation::Exp); + case kTfLiteBuiltinFloor: + return VisitFloorOperator(delegateData, + tfLiteContext, + tfLiteNode, + nodeIndex, + kTfLiteBuiltinFloor); case kTfLiteBuiltinFullyConnected: return VisitFullyConnectedOperator(delegateData, tfLiteContext, @@ -754,6 +790,12 @@ TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData, tfLiteNode, nodeIndex, kTfLiteBuiltinL2Normalization); + case kTfLiteBuiltinL2Pool2d: + return VisitPooling2dOperator(delegateData, + tfLiteContext, + tfLiteNode, + nodeIndex, + kTfLiteBuiltinL2Pool2d); case kTfLiteBuiltinLess: return VisitComparisonOperator(delegateData, tfLiteContext, @@ -808,6 +850,18 @@ TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData, nodeIndex, kTfLiteBuiltinLogicalOr, armnn::LogicalBinaryOperation::LogicalOr); + case kTfLiteBuiltinLstm: + return VisitLstmOperator(delegateData, + tfLiteContext, + tfLiteNode, + nodeIndex, + kTfLiteBuiltinLstm); + case kTfLiteBuiltinMaxPool2d: + return VisitPooling2dOperator(delegateData, + tfLiteContext, + tfLiteNode, + nodeIndex, + kTfLiteBuiltinMaxPool2d); case kTfLiteBuiltinMean: return VisitControlOperator(delegateData, tfLiteContext, |