aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/FullyConnected.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/FullyConnected.hpp')
-rw-r--r--delegate/src/FullyConnected.hpp24
1 files changed, 19 insertions, 5 deletions
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp
index 2243ad0e0c..ee553ce81c 100644
--- a/delegate/src/FullyConnected.hpp
+++ b/delegate/src/FullyConnected.hpp
@@ -57,6 +57,22 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
armnn::TensorInfo weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor);
const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
+ // Check that we support fused activation before we attempt to create a layer
+ auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams *>(tfLiteNode->builtin_data);
+ TfLiteFusedActivation activationType;
+ if (tfLiteNodeParameters)
+ {
+ activationType = tfLiteNodeParameters->activation;
+
+ const armnn::TensorInfo& activationOutputInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
+ TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
+ outputTensorInfo, activationType);
+ if(activationStatus != kTfLiteOk)
+ {
+ return kTfLiteError;
+ }
+ }
+
// Fully Connected Layer accepts two dimensional weights input
int32_t weightsDimension = static_cast<int32_t>(weightsTensorInfo.GetNumDimensions());
if (weightsDimension != 2)
@@ -221,9 +237,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
{
Connect(layer, tfLiteNode, delegateData);
}
-
- auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data);
-
+
if (outputTensorInfo.GetNumDimensions() > 2)
{
layer = AddReshapeLayer(tfLiteContext, tfLiteNode, layer, reshapedOutputTensorInfo, outputTensorInfo,
@@ -244,8 +258,8 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
// No Activation
return kTfLiteOk;
}
- // Check Activation
- TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
+
+ // Check and Create Activation
return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
}