diff options
author | Ryan OShea <ryan.oshea3@arm.com> | 2023-01-13 10:19:20 +0000 |
---|---|---|
committer | Colm Donelan <colm.donelan@arm.com> | 2023-01-27 21:03:23 +0000 |
commit | 3ad2e14333fa0ffebe373b05ce582068c4c8f5f0 (patch) | |
tree | c597684297c84ffb71871d96a2d6c778559074c0 /delegate/src/FullyConnected.hpp | |
parent | 3811a97033be66f7a5d8fc3340b0899e0b60f737 (diff) | |
download | armnn-3ad2e14333fa0ffebe373b05ce582068c4c8f5f0.tar.gz |
IVGCVSW-7450 Fix delegate fallback when fused activation is unsupported
In layers that support fused activations, we check for activation
layer support after we already create the base layer. This breaks
the fallback as we already added the base layer to the graph.
* Creates ValidateFusedActivation shared function
* Moves Activation validation higher in the VisitFunction
Signed-off-by: Ryan OShea <ryan.oshea3@arm.com>
Change-Id: I239af360923f695fc374ddeaeefa24c062eaf9e8
Diffstat (limited to 'delegate/src/FullyConnected.hpp')
-rw-r--r-- | delegate/src/FullyConnected.hpp | 24 |
1 file changed, 19 insertions(+), 5 deletions(-)
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp
index 2243ad0e0c..ee553ce81c 100644
--- a/delegate/src/FullyConnected.hpp
+++ b/delegate/src/FullyConnected.hpp
@@ -57,6 +57,22 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
     armnn::TensorInfo weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor);
     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
 
+    // Check that we support fused activation before we attempt to create a layer
+    auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams *>(tfLiteNode->builtin_data);
+    TfLiteFusedActivation activationType;
+    if (tfLiteNodeParameters)
+    {
+        activationType = tfLiteNodeParameters->activation;
+
+        const armnn::TensorInfo& activationOutputInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
+        TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
+                                                                        outputTensorInfo, activationType);
+        if(activationStatus != kTfLiteOk)
+        {
+            return kTfLiteError;
+        }
+    }
+
     // Fully Connected Layer accepts two dimensional weights input
     int32_t weightsDimension = static_cast<int32_t>(weightsTensorInfo.GetNumDimensions());
     if (weightsDimension != 2)
@@ -221,9 +237,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
     {
         Connect(layer, tfLiteNode, delegateData);
     }
-
-    auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data);
-
+
     if (outputTensorInfo.GetNumDimensions() > 2)
     {
         layer = AddReshapeLayer(tfLiteContext, tfLiteNode, layer, reshapedOutputTensorInfo, outputTensorInfo,
@@ -244,8 +258,8 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
         // No Activation
         return kTfLiteOk;
     }
-    // Check Activation
-    TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
+
+    // Check and Create Activation
     return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
}