diff options
Diffstat (limited to 'ArmnnPreparedModel_1_2.hpp')
-rw-r--r-- | ArmnnPreparedModel_1_2.hpp | 77 |
1 file changed, 68 insertions(+), 9 deletions(-)
diff --git a/ArmnnPreparedModel_1_2.hpp b/ArmnnPreparedModel_1_2.hpp index 13d7494e..57deb98c 100644 --- a/ArmnnPreparedModel_1_2.hpp +++ b/ArmnnPreparedModel_1_2.hpp @@ -12,6 +12,7 @@ #include <NeuralNetworks.h> #include <armnn/ArmNN.hpp> +#include <armnn/Threadpool.hpp> #include <string> #include <vector> @@ -44,7 +45,21 @@ public: armnn::IRuntime* runtime, const HalModel& model, const std::string& requestInputsAndOutputsDumpDir, - const bool gpuProfilingEnabled); + const bool gpuProfilingEnabled, + const bool asyncModelExecutionEnabled = false, + const unsigned int numberOfThreads = 1, + const bool importEnabled = false, + const bool exportEnabled = false); + + ArmnnPreparedModel_1_2(armnn::NetworkId networkId, + armnn::IRuntime* runtime, + const std::string& requestInputsAndOutputsDumpDir, + const bool gpuProfilingEnabled, + const bool asyncModelExecutionEnabled = false, + const unsigned int numberOfThreads = 1, + const bool importEnabled = false, + const bool exportEnabled = false, + const bool preparedFromCache = false); virtual ~ArmnnPreparedModel_1_2(); @@ -73,9 +88,38 @@ public: /// Executes this model with dummy inputs (e.g. all zeroes). 
/// \return false on failure, otherwise true - bool ExecuteWithDummyInputs(); + bool ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs); private: + + template<typename CallbackContext> + class ArmnnThreadPoolCallback_1_2 : public armnn::IAsyncExecutionCallback + { + public: + ArmnnThreadPoolCallback_1_2(ArmnnPreparedModel_1_2<HalVersion>* model, + std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, + std::vector<V1_2::OutputShape> outputShapes, + std::shared_ptr<armnn::InputTensors>& inputTensors, + std::shared_ptr<armnn::OutputTensors>& outputTensors, + CallbackContext callbackContext) : + m_Model(model), + m_MemPools(pMemPools), + m_OutputShapes(outputShapes), + m_InputTensors(inputTensors), + m_OutputTensors(outputTensors), + m_CallbackContext(callbackContext) + {} + + void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override; + + ArmnnPreparedModel_1_2<HalVersion>* m_Model; + std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools; + std::vector<V1_2::OutputShape> m_OutputShapes; + std::shared_ptr<armnn::InputTensors> m_InputTensors; + std::shared_ptr<armnn::OutputTensors> m_OutputTensors; + CallbackContext m_CallbackContext; + }; + Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request, V1_2::MeasureTiming measureTiming, CallbackAsync_1_2 callback); @@ -101,17 +145,32 @@ private: template <typename TensorBindingCollection> void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings); - armnn::NetworkId m_NetworkId; - armnn::IRuntime* m_Runtime; - V1_2::Model m_Model; + /// schedule the graph prepared from the request for execution + template<typename CallbackContext> + void ScheduleGraphForExecution( + std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, + std::shared_ptr<armnn::InputTensors>& inputTensors, + std::shared_ptr<armnn::OutputTensors>& outputTensors, + CallbackContext m_CallbackContext); + + 
armnn::NetworkId m_NetworkId; + armnn::IRuntime* m_Runtime; + V1_2::Model m_Model; // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads // It is specific to this class, so it is declared as static here static RequestThread<ArmnnPreparedModel_1_2, HalVersion, - CallbackContext_1_2> m_RequestThread; - uint32_t m_RequestCount; - const std::string& m_RequestInputsAndOutputsDumpDir; - const bool m_GpuProfilingEnabled; + CallbackContext_1_2> m_RequestThread; + uint32_t m_RequestCount; + const std::string& m_RequestInputsAndOutputsDumpDir; + const bool m_GpuProfilingEnabled; + // Static to allow sharing of threadpool between ArmnnPreparedModel instances + static std::unique_ptr<armnn::Threadpool> m_Threadpool; + std::shared_ptr<IWorkingMemHandle> m_WorkingMemHandle; + const bool m_AsyncModelExecutionEnabled; + const bool m_EnableImport; + const bool m_EnableExport; + const bool m_PreparedFromCache; }; } |