diff options
Diffstat (limited to 'delegate/src/FullyConnected.hpp')
-rw-r--r-- | delegate/src/FullyConnected.hpp | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp index 2243ad0e0c..ee553ce81c 100644 --- a/delegate/src/FullyConnected.hpp +++ b/delegate/src/FullyConnected.hpp @@ -57,6 +57,22 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, armnn::TensorInfo weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor); const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true); + // Check that we support fused activation before we attempt to create a layer + auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams *>(tfLiteNode->builtin_data); + TfLiteFusedActivation activationType; + if (tfLiteNodeParameters) + { + activationType = tfLiteNodeParameters->activation; + + const armnn::TensorInfo& activationOutputInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true); + TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo, + outputTensorInfo, activationType); + if(activationStatus != kTfLiteOk) + { + return kTfLiteError; + } + } + // Fully Connected Layer accepts two dimensional weights input int32_t weightsDimension = static_cast<int32_t>(weightsTensorInfo.GetNumDimensions()); if (weightsDimension != 2) @@ -221,9 +237,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, { Connect(layer, tfLiteNode, delegateData); } - - auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data); - + if (outputTensorInfo.GetNumDimensions() > 2) { layer = AddReshapeLayer(tfLiteContext, tfLiteNode, layer, reshapedOutputTensorInfo, outputTensorInfo, @@ -244,8 +258,8 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, // No Activation return kTfLiteOk; } - // Check Activation - TfLiteFusedActivation activationType = tfLiteNodeParameters->activation; + + // Check and Create Activation return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData); } |