diff options
Diffstat (limited to 'delegate/opaque/src/SharedFunctions.cpp')
-rw-r--r-- | delegate/opaque/src/SharedFunctions.cpp | 100 |
1 files changed, 100 insertions, 0 deletions
diff --git a/delegate/opaque/src/SharedFunctions.cpp b/delegate/opaque/src/SharedFunctions.cpp index 93eb143bd0..0a0c630697 100644 --- a/delegate/opaque/src/SharedFunctions.cpp +++ b/delegate/opaque/src/SharedFunctions.cpp @@ -100,5 +100,105 @@ TfLiteStatus ValidateFusedActivationOperator(DelegateData& delegateData, return isSupported ? kTfLiteOk : kTfLiteError; } +TfLiteOpaqueNode* GetNodeConnectedToInput(TfLiteOpaqueContext* tfLiteContext, + int32_t& connectedIndex, + int32_t inputIdx) +{ + TfLiteIntArray* executionPlan = nullptr; + if (TfLiteOpaqueContextGetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk) + { + TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(tfLiteContext, "TfLiteArmnnDelegate: Unable to get graph execution plan."); + return nullptr; + } + + for (int i = 0; i < executionPlan->size; ++i) + { + connectedIndex = executionPlan->data[i]; + + // If TfLite nodes can be delegated to ArmNN + TfLiteOpaqueNode* connectedNode = nullptr; + TfLiteRegistrationExternal* tfLiteRegistration = nullptr; + if (TfLiteOpaqueContextGetNodeAndRegistration( + tfLiteContext, connectedIndex, &connectedNode, &tfLiteRegistration) != kTfLiteOk) + { + TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(tfLiteContext, + "TfLiteArmnnOpaqueDelegate: Unable to get node and registration for node " + "%d.", connectedIndex); + continue; + } + int numOutputs = 0; + const int* outputTensors; + + if (TfLiteOpaqueNodeOutputs(connectedNode, &outputTensors, &numOutputs) != kTfLiteOk) + { + TF_LITE_OPAQUE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ", + connectedIndex); + continue; + } + + for (int j= 0; j < numOutputs; ++j) + { + if (outputTensors[j] == inputIdx) + { + return connectedNode; + } + } + } + // No node found so set connectedIndex to -1 + connectedIndex = -1; + return nullptr; +} + +bool WillInputBeOptimizedToConst(TfLiteOpaqueContext* tfLiteContext, int32_t inputIdx) +{ + int32_t connectedIndex; + TfLiteOpaqueNode* connectedNode = GetNodeConnectedToInput(tfLiteContext, connectedIndex, inputIdx); + + if (connectedNode) + { + TfLiteRegistrationExternal* tfLiteRegistration = nullptr; + + if (TfLiteOpaqueContextGetNodeAndRegistration(tfLiteContext, connectedIndex, &connectedNode, + &tfLiteRegistration) == kTfLiteOk) + { + switch (TfLiteRegistrationExternalGetBuiltInCode(tfLiteRegistration)) + { + case kTfLiteBuiltinDequantize: + { + auto numInputs = TfLiteOpaqueNodeNumberOfInputs(connectedNode); + if (numInputs >= 1) + { + const int* inputTensors; + if (TfLiteOpaqueNodeInputs(connectedNode, &inputTensors, &numInputs) != kTfLiteOk) + { + TF_LITE_OPAQUE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ", + connectedIndex); + return kTfLiteError; + } + const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, + inputTensors[0]); + + // If the input to the Dequantize is a Constant then both that Constant layer and the Dequantize + // layer will be replaced by a single Constant layer containing the dequantized values. + if (IsConstantTensor(tfLiteInputTensor)) + { + return true; + } + } + break; + } + default: + { + } + } + } + } + return false; +} + } // namespace armnnDelegate |