From e813d67f86df41a238ff79b5c554ef5027f56576 Mon Sep 17 00:00:00 2001 From: Keith Davis Date: Thu, 22 Apr 2021 10:10:34 +0100 Subject: IVGCVSW-5813 Add Async Queue to IRuntime Signed-off-by: Keith Davis Change-Id: Icc0d131c8ee2e9748e2f14762a75962b39c10f9d --- include/armnn/IAsyncExecutionCallback.hpp | 43 +++++++++++++++++++++++++++++++ include/armnn/IRuntime.hpp | 23 ++++++++++++++--- include/armnn/Types.hpp | 22 +++++++++++++--- 3 files changed, 81 insertions(+), 7 deletions(-) create mode 100644 include/armnn/IAsyncExecutionCallback.hpp (limited to 'include/armnn') diff --git a/include/armnn/IAsyncExecutionCallback.hpp b/include/armnn/IAsyncExecutionCallback.hpp new file mode 100644 index 0000000000..045ec4581f --- /dev/null +++ b/include/armnn/IAsyncExecutionCallback.hpp @@ -0,0 +1,43 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "Types.hpp" + +namespace armnn +{ + +namespace experimental +{ + +class IAsyncExecutionCallback; +using IAsyncExecutionCallbackPtr = std::shared_ptr<IAsyncExecutionCallback>; + +class IAsyncExecutionCallback +{ +public: + virtual ~IAsyncExecutionCallback() {}; + + // Notify the AsyncExecutionCallback object of the armnn execution status + virtual void Notify(armnn::Status status, InferenceTimingPair timeTaken) = 0; + + // Block the calling thread until the AsyncExecutionCallback object allows it to proceed + virtual void Wait() const = 0; + + // Retrieve the ArmNN Status from the AsyncExecutionCallback that has been notified + virtual armnn::Status GetStatus() const = 0; + + // Retrieve the start time before executing the inference + virtual HighResolutionClock GetStartTime() const = 0; + + // Retrieve the time after executing the inference + virtual HighResolutionClock GetEndTime() const = 0; + +}; + +} // experimental + +} // namespace armnn diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp index 55c57974dc..f296a5f564 100644 --- 
a/include/armnn/IRuntime.hpp +++ b/include/armnn/IRuntime.hpp @@ -8,6 +8,7 @@ #include "INetwork.hpp" #include "IProfiler.hpp" #include "IWorkingMemHandle.hpp" +#include "IAsyncExecutionCallback.hpp" #include "Tensor.hpp" #include "Types.hpp" #include "TypesUtils.hpp" @@ -31,20 +32,24 @@ struct INetworkProperties ARMNN_DEPRECATED_MSG("Please use INetworkProperties constructor with MemorySource argument") INetworkProperties(bool importEnabled = false, bool exportEnabled = false, - bool asyncEnabled = false) + bool asyncEnabled = false, + size_t numThreads = 0) : m_ImportEnabled(importEnabled) , m_ExportEnabled(exportEnabled) , m_AsyncEnabled(asyncEnabled) + , m_NumThreads(numThreads) , m_InputSource(MemorySource::Undefined) , m_OutputSource(MemorySource::Undefined) {} INetworkProperties(bool asyncEnabled, MemorySource m_InputSource, - MemorySource m_OutputSource) + MemorySource m_OutputSource, + size_t numThreads = 0) : m_ImportEnabled(m_InputSource != MemorySource::Undefined) , m_ExportEnabled(m_OutputSource != MemorySource::Undefined) , m_AsyncEnabled(asyncEnabled) + , m_NumThreads(numThreads) , m_InputSource(m_InputSource) , m_OutputSource(m_OutputSource) {} @@ -54,7 +59,9 @@ struct INetworkProperties /// Deprecated and will be removed in future release. const bool m_ExportEnabled; - const bool m_AsyncEnabled; + const bool m_AsyncEnabled; + const size_t m_NumThreads; + const MemorySource m_InputSource; const MemorySource m_OutputSource; @@ -184,6 +191,16 @@ public: const InputTensors& inputTensors, const OutputTensors& outputTensors); + /// This is an experimental function + /// Schedule a thread safe execution by taking the input tensors and an execution priority for Quality of Service. + /// The output tensors will then be filled and the callback object will notify that the execution has either + /// succeeded or failed. 
+ void Schedule(NetworkId networkId, + const InputTensors& inputTensors, + const OutputTensors& outputTensors, + const QosExecPriority priority, + std::shared_ptr<IAsyncExecutionCallback> callback); + /// Unloads a network from the IRuntime. /// At the moment this only removes the network from the m_Impl->m_Network. /// This might need more work in the future to be AndroidNN compliant. diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index bc41003c57..9e46d08501 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -8,6 +8,7 @@ #include <functional> #include <memory> #include <stdint.h> +#include <chrono> #include "BackendId.hpp" #include "Exceptions.hpp" #include "Deprecated.hpp" @@ -20,6 +21,9 @@ constexpr unsigned int MaxNumOfTensorDimensions = 5U; /// The lowest performance data capture interval we support is 10 miliseconds. constexpr unsigned int LOWEST_CAPTURE_PERIOD = 10000u; +/// Variable to control expire rate of priority queue +constexpr unsigned int EXPIRE_RATE = 3U; + /// @enum Status enumeration /// @var Status::Successful /// @var Status::Failure @@ -31,14 +35,14 @@ enum class Status enum class DataType { - Float16 = 0, - Float32 = 1, + Float16 = 0, + Float32 = 1, QAsymmU8 = 2, Signed32 = 3, - Boolean = 4, + Boolean = 4, QSymmS16 = 5, QuantizedSymm8PerAxis ARMNN_DEPRECATED_ENUM_MSG("Per Axis property inferred by number of scales in TensorInfo") = 6, - QSymmS8 = 7, + QSymmS8 = 7, QAsymmS8 = 8, BFloat16 = 9, Signed64 = 10, @@ -53,6 +57,13 @@ enum class DataLayout NHWC = 2 }; +enum class QosExecPriority +{ + Low = 0, + Medium = 1, + High = 2 +}; + enum class ActivationFunction { Sigmoid = 0, @@ -304,6 +315,9 @@ class ITensorHandle; /// @param tensorHandle - TensorHandle for the input tensor to the Debug layer using DebugCallbackFunction = std::function<void(LayerGuid guid, unsigned int slotIndex, ITensorHandle* tensorHandle)>; +/// Define a timer and associated inference ID for recording execution times +using HighResolutionClock = std::chrono::high_resolution_clock::time_point; +using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>; namespace profiling { -- cgit v1.2.1