aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/armnn_delegate.cpp
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2020-10-19 17:35:30 +0100
committerSadik Armagan <sadik.armagan@arm.com>2020-10-19 16:34:15 +0000
commit3c24f43ff9afb50898d6a73ccddbc0936f72fdad (patch)
treeb4101aab6f085279cddefdc539fb3f622fc8a1b7 /delegate/src/armnn_delegate.cpp
parent418c7dd833accc061ba4cba2743631e582962915 (diff)
downloadarmnn-3c24f43ff9afb50898d6a73ccddbc0936f72fdad.tar.gz
IVGCVSW-5365 'Create the TfLite Delegate subdirectory in ArmNN'
* Created delegate sub-directory under armnn * Created Delegate, ArmnnSubgraph and DelegateOptions classes * Created cmake files. * Integrated doctest (under MIT license) as testing framework Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: If725ebd62c40a97c783cdad22bca48709d44338c
Diffstat (limited to 'delegate/src/armnn_delegate.cpp')
-rw-r--r--delegate/src/armnn_delegate.cpp185
1 files changed, 185 insertions, 0 deletions
diff --git a/delegate/src/armnn_delegate.cpp b/delegate/src/armnn_delegate.cpp
new file mode 100644
index 0000000000..f8a8aca139
--- /dev/null
+++ b/delegate/src/armnn_delegate.cpp
@@ -0,0 +1,185 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <armnn_delegate.hpp>
+#include <algorithm>
+
+namespace armnnDelegate
+{
+
+Delegate::Delegate(armnnDelegate::DelegateOptions options)
+ : m_Runtime(nullptr, nullptr),
+ m_Options(std::move(options))
+{
+ // Create ArmNN Runtime
+ armnn::IRuntime::CreationOptions runtimeOptions;
+ m_Runtime = armnn::IRuntime::Create(runtimeOptions);
+
+ std::vector<armnn::BackendId> backends;
+
+ if (m_Runtime)
+ {
+ const armnn::BackendIdSet supportedDevices = m_Runtime->GetDeviceSpec().GetSupportedBackends();
+ for (auto& backend : m_Options.GetBackends())
+ {
+ if (std::find(supportedDevices.cbegin(), supportedDevices.cend(), backend) == supportedDevices.cend())
+ {
+ TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
+ "TfLiteArmnnDelegate: Requested unknown backend %s", backend.Get().c_str());
+ }
+ else
+ {
+ backends.push_back(backend);
+ }
+ }
+ }
+
+ if (backends.empty())
+ {
+ // No known backend specified
+ throw armnn::InvalidArgumentException("TfLiteArmnnDelegate: No known backend specified.");
+ }
+ m_Options.SetBackends(backends);
+
+ TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "TfLiteArmnnDelegate: Created TfLite ArmNN delegate.");
+}
+
+TfLiteIntArray* Delegate::CollectOperatorsToDelegate(TfLiteContext* tfLiteContext)
+{
+ TfLiteIntArray* executionPlan = nullptr;
+ if (tfLiteContext->GetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk)
+ {
+ TF_LITE_KERNEL_LOG(tfLiteContext, "TfLiteArmnnDelegate: Unable to get graph execution plan.");
+ return nullptr;
+ }
+
+ // Null INetworkPtr
+ armnn::INetworkPtr nullNetworkPtr(nullptr, nullptr);
+
+ TfLiteIntArray* nodesToDelegate = TfLiteIntArrayCreate(executionPlan->size);
+ nodesToDelegate->size = 0;
+ for (int i = 0; i < executionPlan->size; ++i)
+ {
+ const int nodeIndex = executionPlan->data[i];
+
+ // If TfLite nodes can be delegated to ArmNN
+ TfLiteNode* tfLiteNode = nullptr;
+ TfLiteRegistration* tfLiteRegistration = nullptr;
+ if (tfLiteContext->GetNodeAndRegistration(
+ tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk)
+ {
+ TF_LITE_KERNEL_LOG(tfLiteContext,
+ "TfLiteArmnnDelegate: Unable to get node and registration for node %d.",
+ nodeIndex);
+ continue;
+ }
+
+ if (ArmnnSubgraph::VisitNode(
+ nullNetworkPtr, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex) != kTfLiteOk)
+ {
+ // node is not supported by ArmNN
+ continue;
+ }
+
+ nodesToDelegate->data[nodesToDelegate->size++] = nodeIndex;
+ }
+
+ std::sort(&nodesToDelegate->data[0],
+ &nodesToDelegate->data[nodesToDelegate->size]);
+
+ return nodesToDelegate;
+}
+
+TfLiteDelegate* Delegate::GetDelegate()
+{
+ return &m_Delegate;
+}
+
+ArmnnSubgraph* ArmnnSubgraph::Create(TfLiteContext* tfLiteContext,
+ const TfLiteDelegateParams* parameters,
+ const Delegate* delegate)
+{
+ TfLiteIntArray* executionPlan;
+ if (tfLiteContext->GetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk)
+ {
+ return nullptr;
+ }
+
+ // Construct ArmNN network
+ using NetworkOptions = std::vector<armnn::BackendOptions>;
+ armnn::NetworkOptions networkOptions = {};
+ armnn::NetworkId networkId;
+ armnn::INetworkPtr network = armnn::INetwork::Create(networkOptions);
+
+ // Parse TfLite delegate nodes to ArmNN nodes
+ for (int i = 0; i < parameters->nodes_to_replace->size; ++i)
+ {
+ const int nodeIndex = parameters->nodes_to_replace->data[i];
+
+ TfLiteNode* tfLiteNode = nullptr;
+ TfLiteRegistration* tfLiteRegistration = nullptr;
+ if (tfLiteContext->GetNodeAndRegistration(
+ tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk)
+ {
+ throw armnn::Exception("TfLiteArmnnDelegate: Unable to get node registration: " + nodeIndex);
+ }
+
+ if (VisitNode(network, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex) != kTfLiteOk)
+ {
+ throw armnn::Exception("TfLiteArmnnDelegate: Unable to parse node: " + nodeIndex);
+ }
+ }
+
+ // Optimise Arm NN network
+ armnn::IOptimizedNetworkPtr optNet =
+ armnn::Optimize(*network, delegate->m_Options.GetBackends(), delegate->m_Runtime->GetDeviceSpec());
+ if (!optNet)
+ {
+ // Optimize Failed
+ throw armnn::Exception("TfLiteArmnnDelegate: Unable to optimize the network!");
+ }
+ // Load graph into runtime
+ delegate->m_Runtime->LoadNetwork(networkId, std::move(optNet));
+
+ // Create a new SubGraph with networkId and runtime
+ return new ArmnnSubgraph(networkId, delegate->m_Runtime.get());
+}
+
+TfLiteStatus ArmnnSubgraph::Prepare(TfLiteContext* tfLiteContext)
+{
+ return kTfLiteOk;
+}
+
+TfLiteStatus ArmnnSubgraph::Invoke(TfLiteContext* tfLiteContext)
+{
+ /// Get the Input Tensors and OutputTensors from the context
+ /// Execute the network
+ //m_Runtime->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus ArmnnSubgraph::VisitNode(armnn::INetworkPtr& network,
+ TfLiteContext* tfLiteContext,
+ TfLiteRegistration* tfLiteRegistration,
+ TfLiteNode* tfLiteNode,
+ int nodeIndex)
+{
+ /*
+ * Take the node and check what operator it is and VisitXXXLayer()
+ * In the VisitXXXLayer() function parse TfLite node to Arm NN Layer and add it to tho network graph
+ *switch (tfLiteRegistration->builtin_code)
+ * {
+ * case kTfLiteBuiltinAbs:
+ * return VisitAbsLayer(...);
+ * ...
+ * default:
+ * return kTfLiteError;
+ * }
+ */
+ return kTfLiteError;
+}
+
+} // armnnDelegate namespace \ No newline at end of file