diff options
Diffstat (limited to 'delegate/classic/src/SharedFunctions.cpp')
-rw-r--r-- | delegate/classic/src/SharedFunctions.cpp | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/delegate/classic/src/SharedFunctions.cpp b/delegate/classic/src/SharedFunctions.cpp index bcff3a1dd0..53136b521e 100644 --- a/delegate/classic/src/SharedFunctions.cpp +++ b/delegate/classic/src/SharedFunctions.cpp @@ -110,6 +110,83 @@ TfLiteStatus ValidateFusedActivationOperator(DelegateData& delegateData, return isSupported ? kTfLiteOk : kTfLiteError; } +TfLiteNode* GetNodeConnectedToInput(TfLiteContext* tfLiteContext, + int32_t& connectedIndex, + int32_t inputIdx) +{ + TfLiteIntArray* executionPlan = nullptr; + if (tfLiteContext->GetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk) + { + TF_LITE_KERNEL_LOG(tfLiteContext, "TfLiteArmnnDelegate: Unable to get graph execution plan."); + return nullptr; + } + + for (int i = 0; i < executionPlan->size; ++i) + { + connectedIndex = executionPlan->data[i]; + + // If TfLite nodes can be delegated to ArmNN + TfLiteNode* connectedNode = nullptr; + TfLiteRegistration* tfLiteRegistration = nullptr; + if (tfLiteContext->GetNodeAndRegistration( + tfLiteContext, connectedIndex, &connectedNode, &tfLiteRegistration) != kTfLiteOk) + { + TF_LITE_KERNEL_LOG(tfLiteContext, + "TfLiteArmnnDelegate: Unable to get node and registration for node %d.", + connectedIndex); + continue; + } + for (int j= 0; j < connectedNode->outputs->size; ++j) + { + if (connectedNode->outputs->data[j] == inputIdx) + { + return connectedNode; + } + } + } + // No node found so set connectedIndex to -1 + connectedIndex = -1; + return nullptr; +} + +bool WillInputBeOptimizedToConst(TfLiteContext* tfLiteContext, int32_t inputIdx) +{ + int32_t connectedIndex; + TfLiteNode* connectedNode = GetNodeConnectedToInput(tfLiteContext, connectedIndex, inputIdx); + + if (connectedNode) + { + TfLiteRegistration* tfLiteRegistration = nullptr; + + if (tfLiteContext->GetNodeAndRegistration(tfLiteContext, connectedIndex, &connectedNode, &tfLiteRegistration) + == kTfLiteOk) + { + switch (tfLiteRegistration->builtin_code) + { + case kTfLiteBuiltinDequantize: + { + if (connectedNode->inputs->size >= 1) + { + const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors; + const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[connectedNode->inputs->data[0]]; + + // If the input to the Dequantize is a Constant then both that Constant layer and the Dequantize + // layer will be replaced by a single Constant layer containing the dequantized values. + if (tflite::IsConstantTensor(&tfLiteInputTensor)) + { + return true; + } + } + break; + } + default: + { + } + } + } + } + return false; +} } // namespace armnnDelegate |