aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/Convolution.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/Convolution.hpp')
-rw-r--r--delegate/src/Convolution.hpp63
1 files changed, 53 insertions, 10 deletions
diff --git a/delegate/src/Convolution.hpp b/delegate/src/Convolution.hpp
index e307bb9be3..7ea3a3a987 100644
--- a/delegate/src/Convolution.hpp
+++ b/delegate/src/Convolution.hpp
@@ -1,11 +1,12 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "DelegateUtils.hpp"
+#include "SharedFunctions.hpp"
#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
@@ -100,6 +101,22 @@ TfLiteStatus VisitConv2dOperator(DelegateData& delegateData,
const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
+ auto* tfLiteNodeParameters = reinterpret_cast<TfLiteConvParams*>(tfLiteNode->builtin_data);
+ TfLiteFusedActivation activationType;
+ if (tfLiteNodeParameters)
+ {
+ activationType = tfLiteNodeParameters->activation;
+
+ const armnn::TensorInfo& activationOutputInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
+ TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
+ outputTensorInfo, activationType);
+ if(activationStatus != kTfLiteOk)
+ {
+ return kTfLiteError;
+ }
+
+ }
+
armnn::TensorInfo filterTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteFilterTensor);
armnn::TensorInfo biasTensorInfo;
@@ -198,14 +215,12 @@ TfLiteStatus VisitConv2dOperator(DelegateData& delegateData,
Connect(layer, tfLiteNode, delegateData);
- auto* tfLiteNodeParameters = reinterpret_cast<TfLiteConvParams*>(tfLiteNode->builtin_data);
if (!tfLiteNodeParameters)
{
// No Activation
return kTfLiteOk;
}
- // Check activation
- TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
+ // Check and Create activation
return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
}
@@ -263,6 +278,22 @@ TfLiteStatus VisitConv3dOperator(DelegateData& delegateData,
const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
+ auto* tfLiteNodeParameters = reinterpret_cast<TfLiteConv3DParams*>(tfLiteNode->builtin_data);
+ TfLiteFusedActivation activationType;
+ if (tfLiteNodeParameters)
+ {
+ activationType = tfLiteNodeParameters->activation;
+
+ const armnn::TensorInfo& activationOutputInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
+ TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
+ outputTensorInfo, activationType);
+ if(activationStatus != kTfLiteOk)
+ {
+ return kTfLiteError;
+ }
+
+ }
+
armnn::TensorInfo filterTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteFilterTensor);
armnn::TensorInfo biasTensorInfo;
@@ -362,15 +393,13 @@ TfLiteStatus VisitConv3dOperator(DelegateData& delegateData,
Connect(layer, tfLiteNode, delegateData);
- auto* tfLiteNodeParameters = reinterpret_cast<TfLiteConv3DParams*>(tfLiteNode->builtin_data);
if (!tfLiteNodeParameters)
{
// No Activation
return kTfLiteOk;
}
- // Check activation
- TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
+ // Check and create activation
return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
}
#endif
@@ -460,6 +489,22 @@ TfLiteStatus VisitDepthwiseConv2dOperator(DelegateData& delegateData,
const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
+ auto* tfLiteNodeParameters = reinterpret_cast<TfLiteDepthwiseConvParams *>(tfLiteNode->builtin_data);
+ TfLiteFusedActivation activationType;
+ if (tfLiteNodeParameters)
+ {
+ activationType = tfLiteNodeParameters->activation;
+
+ const armnn::TensorInfo& activationOutputInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
+ TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
+ outputTensorInfo, activationType);
+ if(activationStatus != kTfLiteOk)
+ {
+ return kTfLiteError;
+ }
+
+ }
+
armnn::TensorInfo filterTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteFilterTensor);
// Assuming input is NHWC
@@ -553,14 +598,12 @@ TfLiteStatus VisitDepthwiseConv2dOperator(DelegateData& delegateData,
outputSlot.SetTensorInfo(outputTensorInfo);
Connect(layer, tfLiteNode, delegateData);
- auto* tfLiteNodeParameters = reinterpret_cast<TfLiteDepthwiseConvParams*>(tfLiteNode->builtin_data);
if (!tfLiteNodeParameters)
{
// No Activation
return kTfLiteOk;
}
- // Check activation
- TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
+ // Check and create activation
return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
}