diff options
Diffstat (limited to 'delegate/src/Convolution.hpp')
-rw-r--r-- | delegate/src/Convolution.hpp | 63 |
1 files changed, 53 insertions, 10 deletions
diff --git a/delegate/src/Convolution.hpp b/delegate/src/Convolution.hpp index e307bb9be3..7ea3a3a987 100644 --- a/delegate/src/Convolution.hpp +++ b/delegate/src/Convolution.hpp @@ -1,11 +1,12 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "DelegateUtils.hpp" +#include "SharedFunctions.hpp" #include <tensorflow/lite/builtin_ops.h> #include <tensorflow/lite/c/builtin_op_data.h> @@ -100,6 +101,22 @@ TfLiteStatus VisitConv2dOperator(DelegateData& delegateData, const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor); const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true); + auto* tfLiteNodeParameters = reinterpret_cast<TfLiteConvParams*>(tfLiteNode->builtin_data); + TfLiteFusedActivation activationType; + if (tfLiteNodeParameters) + { + activationType = tfLiteNodeParameters->activation; + + const armnn::TensorInfo& activationOutputInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true); + TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo, + outputTensorInfo, activationType); + if(activationStatus != kTfLiteOk) + { + return kTfLiteError; + } + + } + armnn::TensorInfo filterTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteFilterTensor); armnn::TensorInfo biasTensorInfo; @@ -198,14 +215,12 @@ TfLiteStatus VisitConv2dOperator(DelegateData& delegateData, Connect(layer, tfLiteNode, delegateData); - auto* tfLiteNodeParameters = reinterpret_cast<TfLiteConvParams*>(tfLiteNode->builtin_data); if (!tfLiteNodeParameters) { // No Activation return kTfLiteOk; } - // Check activation - TfLiteFusedActivation activationType = tfLiteNodeParameters->activation; + // Check and Create activation return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData); } @@ -263,6 +278,22 @@ TfLiteStatus VisitConv3dOperator(DelegateData& delegateData, const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor); const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true); + auto* tfLiteNodeParameters = reinterpret_cast<TfLiteConv3DParams*>(tfLiteNode->builtin_data); + TfLiteFusedActivation activationType; + if (tfLiteNodeParameters) + { + activationType = tfLiteNodeParameters->activation; + + const armnn::TensorInfo& activationOutputInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true); + TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo, + outputTensorInfo, activationType); + if(activationStatus != kTfLiteOk) + { + return kTfLiteError; + } + + } + armnn::TensorInfo filterTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteFilterTensor); armnn::TensorInfo biasTensorInfo; @@ -362,15 +393,13 @@ TfLiteStatus VisitConv3dOperator(DelegateData& delegateData, Connect(layer, tfLiteNode, delegateData); - auto* tfLiteNodeParameters = reinterpret_cast<TfLiteConv3DParams*>(tfLiteNode->builtin_data); if (!tfLiteNodeParameters) { // No Activation return kTfLiteOk; } - // Check activation - TfLiteFusedActivation activationType = tfLiteNodeParameters->activation; + // Check and create activation return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData); } #endif @@ -460,6 +489,22 @@ TfLiteStatus VisitDepthwiseConv2dOperator(DelegateData& delegateData, const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor); const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true); + auto* tfLiteNodeParameters = reinterpret_cast<TfLiteDepthwiseConvParams *>(tfLiteNode->builtin_data); + TfLiteFusedActivation activationType; + if (tfLiteNodeParameters) + { + activationType = tfLiteNodeParameters->activation; + + const armnn::TensorInfo& activationOutputInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true); + TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo, + outputTensorInfo, activationType); + if(activationStatus != kTfLiteOk) + { + return kTfLiteError; + } + + } + armnn::TensorInfo filterTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteFilterTensor); // Assuming input is NHWC @@ -553,14 +598,12 @@ TfLiteStatus VisitDepthwiseConv2dOperator(DelegateData& delegateData, outputSlot.SetTensorInfo(outputTensorInfo); Connect(layer, tfLiteNode, delegateData); - auto* tfLiteNodeParameters = reinterpret_cast<TfLiteDepthwiseConvParams*>(tfLiteNode->builtin_data); if (!tfLiteNodeParameters) { // No Activation return kTfLiteOk; } - // Check activation - TfLiteFusedActivation activationType = tfLiteNodeParameters->activation; + // Check and create activation return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData); } |