diff options
author | Ryan OShea <ryan.oshea3@arm.com> | 2022-06-10 14:49:11 +0100 |
---|---|---|
committer | Nikhil Raj <nikhil.raj@arm.com> | 2022-06-16 12:46:46 +0100 |
commit | 57480cdc5b2b3a88b6fca40089d0fb7521d832b2 (patch) | |
tree | 95cd7afbc1a6e2a748a7915f9bfaa033116cc5c8 /delegate/src/armnn_delegate.cpp | |
parent | 7bbb79b53f95af00cfbf888fd97b1bdca81612ed (diff) | |
download | armnn-57480cdc5b2b3a88b6fca40089d0fb7521d832b2.tar.gz |
IVGCVSW-6946 Add Pool3D to tflite delegate
* Add new test and test helper for Pool3d
* Add new custom operator to switch in armnn_delegate.cpp
* Add new pool3d function to pooling.hpp
* Update doxygen
Signed-off-by: Ryan OShea <ryan.oshea3@arm.com>
Change-Id: I77a541bf423b337c749e70c564cdd727efe2fd05
Diffstat (limited to 'delegate/src/armnn_delegate.cpp')
-rw-r--r-- | delegate/src/armnn_delegate.cpp | 34 |
1 files changed, 30 insertions, 4 deletions
diff --git a/delegate/src/armnn_delegate.cpp b/delegate/src/armnn_delegate.cpp index 6e1a91f9e4..bb2f3c319a 100644 --- a/delegate/src/armnn_delegate.cpp +++ b/delegate/src/armnn_delegate.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -495,6 +495,32 @@ TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData, { switch (tfLiteRegistration->builtin_code) { + case kTfLiteBuiltinCustom: + { +#if defined(ARMNN_POST_TFLITE_2_5) + // Custom operators are defined by the name rather than the builtin code. + // Parse the custom_name param in the registration to point to the correct visitor function. + std::string customOperatorName = tfLiteRegistration->custom_name; + if ( customOperatorName == "AveragePool3D" ) + { + return VisitPooling3dOperator(delegateData, + tfLiteContext, + tfLiteNode, + nodeIndex, + customOperatorName); + } + else if (customOperatorName == "MaxPool3D") + { + return VisitPooling3dOperator(delegateData, + tfLiteContext, + tfLiteNode, + nodeIndex, + customOperatorName); + } +#endif + // Invalid or unsupported custom operator + return kTfLiteError; + } case kTfLiteBuiltinAbs: return VisitElementwiseUnaryOperator(delegateData, tfLiteContext, @@ -520,7 +546,7 @@ TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData, nodeIndex, kTfLiteBuiltinArgMin); case kTfLiteBuiltinAveragePool2d: - return VisitPoolingOperator(delegateData, + return VisitPooling2dOperator(delegateData, tfLiteContext, tfLiteNode, nodeIndex, @@ -667,7 +693,7 @@ TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData, nodeIndex, kTfLiteBuiltinL2Normalization); case kTfLiteBuiltinL2Pool2d: - return VisitPoolingOperator(delegateData, + return VisitPooling2dOperator(delegateData, tfLiteContext, tfLiteNode, nodeIndex, @@ -729,7 +755,7 @@ TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData, nodeIndex, kTfLiteBuiltinLstm); case kTfLiteBuiltinMaxPool2d: - return VisitPoolingOperator(delegateData, + return VisitPooling2dOperator(delegateData, tfLiteContext, tfLiteNode, nodeIndex, |