diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2020-10-19 17:35:30 +0100 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2020-10-19 16:34:15 +0000 |
commit | 3c24f43ff9afb50898d6a73ccddbc0936f72fdad (patch) | |
tree | b4101aab6f085279cddefdc539fb3f622fc8a1b7 /delegate/src/armnn_delegate.cpp | |
parent | 418c7dd833accc061ba4cba2743631e582962915 (diff) | |
download | armnn-3c24f43ff9afb50898d6a73ccddbc0936f72fdad.tar.gz |
IVGCVSW-5365 'Create the TfLite Delegate subdirectory in ArmNN'
* Created delegate sub-directory under armnn
* Created Delegate, ArmnnSubgraph and DelegateOptions classes
* Created cmake files.
* Integrated doctest (under MIT license) as testing framework
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: If725ebd62c40a97c783cdad22bca48709d44338c
Diffstat (limited to 'delegate/src/armnn_delegate.cpp')
-rw-r--r-- | delegate/src/armnn_delegate.cpp | 185 |
1 files changed, 185 insertions, 0 deletions
diff --git a/delegate/src/armnn_delegate.cpp b/delegate/src/armnn_delegate.cpp new file mode 100644 index 0000000000..f8a8aca139 --- /dev/null +++ b/delegate/src/armnn_delegate.cpp @@ -0,0 +1,185 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <armnn_delegate.hpp> +#include <algorithm> + +namespace armnnDelegate +{ + +Delegate::Delegate(armnnDelegate::DelegateOptions options) + : m_Runtime(nullptr, nullptr), + m_Options(std::move(options)) +{ + // Create ArmNN Runtime + armnn::IRuntime::CreationOptions runtimeOptions; + m_Runtime = armnn::IRuntime::Create(runtimeOptions); + + std::vector<armnn::BackendId> backends; + + if (m_Runtime) + { + const armnn::BackendIdSet supportedDevices = m_Runtime->GetDeviceSpec().GetSupportedBackends(); + for (auto& backend : m_Options.GetBackends()) + { + if (std::find(supportedDevices.cbegin(), supportedDevices.cend(), backend) == supportedDevices.cend()) + { + TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, + "TfLiteArmnnDelegate: Requested unknown backend %s", backend.Get().c_str()); + } + else + { + backends.push_back(backend); + } + } + } + + if (backends.empty()) + { + // No known backend specified + throw armnn::InvalidArgumentException("TfLiteArmnnDelegate: No known backend specified."); + } + m_Options.SetBackends(backends); + + TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "TfLiteArmnnDelegate: Created TfLite ArmNN delegate."); +} + +TfLiteIntArray* Delegate::CollectOperatorsToDelegate(TfLiteContext* tfLiteContext) +{ + TfLiteIntArray* executionPlan = nullptr; + if (tfLiteContext->GetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk) + { + TF_LITE_KERNEL_LOG(tfLiteContext, "TfLiteArmnnDelegate: Unable to get graph execution plan."); + return nullptr; + } + + // Null INetworkPtr + armnn::INetworkPtr nullNetworkPtr(nullptr, nullptr); + + TfLiteIntArray* nodesToDelegate = TfLiteIntArrayCreate(executionPlan->size); + nodesToDelegate->size = 0; + for (int i = 0; i < executionPlan->size; ++i) + { + const int nodeIndex = executionPlan->data[i]; + + // If TfLite nodes can be delegated to ArmNN + TfLiteNode* tfLiteNode = nullptr; + TfLiteRegistration* tfLiteRegistration = nullptr; + if (tfLiteContext->GetNodeAndRegistration( + tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk) + { + TF_LITE_KERNEL_LOG(tfLiteContext, + "TfLiteArmnnDelegate: Unable to get node and registration for node %d.", + nodeIndex); + continue; + } + + if (ArmnnSubgraph::VisitNode( + nullNetworkPtr, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex) != kTfLiteOk) + { + // node is not supported by ArmNN + continue; + } + + nodesToDelegate->data[nodesToDelegate->size++] = nodeIndex; + } + + std::sort(&nodesToDelegate->data[0], + &nodesToDelegate->data[nodesToDelegate->size]); + + return nodesToDelegate; +} + +TfLiteDelegate* Delegate::GetDelegate() +{ + return &m_Delegate; +} + +ArmnnSubgraph* ArmnnSubgraph::Create(TfLiteContext* tfLiteContext, + const TfLiteDelegateParams* parameters, + const Delegate* delegate) +{ + TfLiteIntArray* executionPlan; + if (tfLiteContext->GetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk) + { + return nullptr; + } + + // Construct ArmNN network + using NetworkOptions = std::vector<armnn::BackendOptions>; + armnn::NetworkOptions networkOptions = {}; + armnn::NetworkId networkId; + armnn::INetworkPtr network = armnn::INetwork::Create(networkOptions); + + // Parse TfLite delegate nodes to ArmNN nodes + for (int i = 0; i < parameters->nodes_to_replace->size; ++i) + { + const int nodeIndex = parameters->nodes_to_replace->data[i]; + + TfLiteNode* tfLiteNode = nullptr; + TfLiteRegistration* tfLiteRegistration = nullptr; + if (tfLiteContext->GetNodeAndRegistration( + tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk) + { + throw armnn::Exception("TfLiteArmnnDelegate: Unable to get node registration: " + nodeIndex); + } + + if (VisitNode(network, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex) != kTfLiteOk) + { + throw armnn::Exception("TfLiteArmnnDelegate: Unable to parse node: " + nodeIndex); + } + } + + // Optimise Arm NN network + armnn::IOptimizedNetworkPtr optNet = + armnn::Optimize(*network, delegate->m_Options.GetBackends(), delegate->m_Runtime->GetDeviceSpec()); + if (!optNet) + { + // Optimize Failed + throw armnn::Exception("TfLiteArmnnDelegate: Unable to optimize the network!"); + } + // Load graph into runtime + delegate->m_Runtime->LoadNetwork(networkId, std::move(optNet)); + + // Create a new SubGraph with networkId and runtime + return new ArmnnSubgraph(networkId, delegate->m_Runtime.get()); +} + +TfLiteStatus ArmnnSubgraph::Prepare(TfLiteContext* tfLiteContext) +{ + return kTfLiteOk; +} + +TfLiteStatus ArmnnSubgraph::Invoke(TfLiteContext* tfLiteContext) +{ + /// Get the Input Tensors and OutputTensors from the context + /// Execute the network + //m_Runtime->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors); + + return kTfLiteOk; +} + +TfLiteStatus ArmnnSubgraph::VisitNode(armnn::INetworkPtr& network, + TfLiteContext* tfLiteContext, + TfLiteRegistration* tfLiteRegistration, + TfLiteNode* tfLiteNode, + int nodeIndex) +{ + /* + * Take the node and check what operator it is and VisitXXXLayer() + * In the VisitXXXLayer() function parse TfLite node to Arm NN Layer and add it to tho network graph + *switch (tfLiteRegistration->builtin_code) + * { + * case kTfLiteBuiltinAbs: + * return VisitAbsLayer(...); + * ... + * default: + * return kTfLiteError; + * } + */ + return kTfLiteError; +} + +} // armnnDelegate namespace
\ No newline at end of file |