aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorKeith Davis <keith.davis@arm.com>2021-04-22 10:10:34 +0100
committerfinn.williams <finn.williams@arm.com>2021-05-06 19:39:39 +0000
commite813d67f86df41a238ff79b5c554ef5027f56576 (patch)
tree54c2145a78297d61e6e94676729cb2468490ade3 /include
parentd905decd256558bbee165e636ce4242ac3b9c917 (diff)
downloadarmnn-e813d67f86df41a238ff79b5c554ef5027f56576.tar.gz
IVGCVSW-5813 Add Async Queue to IRuntime
Signed-off-by: Keith Davis <keith.davis@arm.com> Change-Id: Icc0d131c8ee2e9748e2f14762a75962b39c10f9d
Diffstat (limited to 'include')
-rw-r--r--include/armnn/IAsyncExecutionCallback.hpp43
-rw-r--r--include/armnn/IRuntime.hpp23
-rw-r--r--include/armnn/Types.hpp22
3 files changed, 81 insertions, 7 deletions
diff --git a/include/armnn/IAsyncExecutionCallback.hpp b/include/armnn/IAsyncExecutionCallback.hpp
new file mode 100644
index 0000000000..045ec4581f
--- /dev/null
+++ b/include/armnn/IAsyncExecutionCallback.hpp
@@ -0,0 +1,43 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "Types.hpp"
+
+namespace armnn
+{
+
+namespace experimental
+{
+
+class IAsyncExecutionCallback;
+using IAsyncExecutionCallbackPtr = std::shared_ptr<IAsyncExecutionCallback>;
+
+class IAsyncExecutionCallback
+{
+public:
+ virtual ~IAsyncExecutionCallback() {};
+
+ // Notify the AsyncExecutionCallback object of the armnn execution status
+ virtual void Notify(armnn::Status status, InferenceTimingPair timeTaken) = 0;
+
+ // Block the calling thread until the AsyncExecutionCallback object allows it to proceed
+ virtual void Wait() const = 0;
+
+ // Retrieve the ArmNN Status from the AsyncExecutionCallback that has been notified
+ virtual armnn::Status GetStatus() const = 0;
+
+ // Retrieve the start time before executing the inference
+ virtual HighResolutionClock GetStartTime() const = 0;
+
+ // Retrieve the time after executing the inference
+ virtual HighResolutionClock GetEndTime() const = 0;
+
+};
+
+} // namespace experimental
+
+} // namespace armnn
diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp
index 55c57974dc..f296a5f564 100644
--- a/include/armnn/IRuntime.hpp
+++ b/include/armnn/IRuntime.hpp
@@ -8,6 +8,7 @@
#include "INetwork.hpp"
#include "IProfiler.hpp"
#include "IWorkingMemHandle.hpp"
+#include "IAsyncExecutionCallback.hpp"
#include "Tensor.hpp"
#include "Types.hpp"
#include "TypesUtils.hpp"
@@ -31,20 +32,24 @@ struct INetworkProperties
ARMNN_DEPRECATED_MSG("Please use INetworkProperties constructor with MemorySource argument")
INetworkProperties(bool importEnabled = false,
bool exportEnabled = false,
- bool asyncEnabled = false)
+ bool asyncEnabled = false,
+ size_t numThreads = 0)
: m_ImportEnabled(importEnabled)
, m_ExportEnabled(exportEnabled)
, m_AsyncEnabled(asyncEnabled)
+ , m_NumThreads(numThreads)
, m_InputSource(MemorySource::Undefined)
, m_OutputSource(MemorySource::Undefined)
{}
INetworkProperties(bool asyncEnabled,
MemorySource m_InputSource,
- MemorySource m_OutputSource)
+ MemorySource m_OutputSource,
+ size_t numThreads = 0)
: m_ImportEnabled(m_InputSource != MemorySource::Undefined)
, m_ExportEnabled(m_OutputSource != MemorySource::Undefined)
, m_AsyncEnabled(asyncEnabled)
+ , m_NumThreads(numThreads)
, m_InputSource(m_InputSource)
, m_OutputSource(m_OutputSource)
{}
@@ -54,7 +59,9 @@ struct INetworkProperties
/// Deprecated and will be removed in future release.
const bool m_ExportEnabled;
- const bool m_AsyncEnabled;
+ const bool m_AsyncEnabled;
+ const size_t m_NumThreads;
+
const MemorySource m_InputSource;
const MemorySource m_OutputSource;
@@ -184,6 +191,16 @@ public:
const InputTensors& inputTensors,
const OutputTensors& outputTensors);
+ /// This is an experimental function
+/// Schedule a thread-safe execution by taking the input tensors and an execution priority for Quality of Service.
+ /// The output tensors will then be filled and the callback object will notify that the execution has either
+ /// succeeded or failed.
+ void Schedule(NetworkId networkId,
+ const InputTensors& inputTensors,
+ const OutputTensors& outputTensors,
+ const QosExecPriority priority,
+ std::shared_ptr<IAsyncExecutionCallback> callback);
+
/// Unloads a network from the IRuntime.
/// At the moment this only removes the network from the m_Impl->m_Network.
/// This might need more work in the future to be AndroidNN compliant.
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index bc41003c57..9e46d08501 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -8,6 +8,7 @@
#include <functional>
#include <memory>
#include <stdint.h>
+#include <chrono>
#include "BackendId.hpp"
#include "Exceptions.hpp"
#include "Deprecated.hpp"
@@ -20,6 +21,9 @@ constexpr unsigned int MaxNumOfTensorDimensions = 5U;
/// The lowest performance data capture interval we support is 10 milliseconds.
constexpr unsigned int LOWEST_CAPTURE_PERIOD = 10000u;
+/// Constant controlling the expiry rate of the priority queue
+constexpr unsigned int EXPIRE_RATE = 3U;
+
/// @enum Status enumeration
/// @var Status::Successful
/// @var Status::Failure
@@ -31,14 +35,14 @@ enum class Status
enum class DataType
{
- Float16 = 0,
- Float32 = 1,
+ Float16 = 0,
+ Float32 = 1,
QAsymmU8 = 2,
Signed32 = 3,
- Boolean = 4,
+ Boolean = 4,
QSymmS16 = 5,
QuantizedSymm8PerAxis ARMNN_DEPRECATED_ENUM_MSG("Per Axis property inferred by number of scales in TensorInfo") = 6,
- QSymmS8 = 7,
+ QSymmS8 = 7,
QAsymmS8 = 8,
BFloat16 = 9,
Signed64 = 10,
@@ -53,6 +57,13 @@ enum class DataLayout
NHWC = 2
};
+enum class QosExecPriority
+{
+ Low = 0,
+ Medium = 1,
+ High = 2
+};
+
enum class ActivationFunction
{
Sigmoid = 0,
@@ -304,6 +315,9 @@ class ITensorHandle;
/// @param tensorHandle - TensorHandle for the input tensor to the Debug layer
using DebugCallbackFunction = std::function<void(LayerGuid guid, unsigned int slotIndex, ITensorHandle* tensorHandle)>;
+/// Define a timer and associated inference ID for recording execution times
+using HighResolutionClock = std::chrono::high_resolution_clock::time_point;
+using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>;
namespace profiling
{