// // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "LoadedNetwork.hpp" #include "DeviceSpec.hpp" #include #include #include #include #include #include #include #include #include #include namespace armnn { using LoadedNetworks = std::unordered_map>; using IReportStructure = arm::pipe::IReportStructure; using IInitialiseProfilingService = arm::pipe::IInitialiseProfilingService; struct RuntimeImpl final : public IReportStructure, public IInitialiseProfilingService { public: /// Loads a complete network into the Runtime. /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference. /// @param [in] network - Complete network to load into the Runtime. /// The runtime takes ownership of the network once passed in. /// @return armnn::Status Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network); /// Load a complete network into the IRuntime. /// @param [out] networkIdOut Unique identifier for the network is returned in this reference. /// @param [in] network Complete network to load into the IRuntime. /// @param [out] errorMessage Error message if there were any errors. /// The runtime takes ownership of the network once passed in. /// @return armnn::Status Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network, std::string& errorMessage); Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network, std::string& errorMessage, const INetworkProperties& networkProperties); armnn::TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const; armnn::TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const; std::vector ImportInputs(NetworkId networkId, const InputTensors& inputTensors, MemorySource forceImportMemorySource = MemorySource::Undefined); std::vector ImportOutputs(NetworkId networkId, const OutputTensors& outputTensors, MemorySource forceImportMemorySource = MemorySource::Undefined); void ClearImportedInputs(NetworkId networkId, const std::vector inputIds); void ClearImportedOutputs(NetworkId networkId, const std::vector outputIds); // Evaluates network using input in inputTensors, outputs filled into outputTensors. Status EnqueueWorkload(NetworkId networkId, const InputTensors& inputTensors, const OutputTensors& outputTensors, std::vector preImportedInputIds = {}, std::vector preImportedOutputIds = {}); /// This is an experimental function. /// Evaluates a network using input in inputTensors and outputs filled into outputTensors. /// This function performs a thread safe execution of the network. Returns once execution is complete. /// Will block until this and any other thread using the same workingMem object completes. Status Execute(IWorkingMemHandle& workingMemHandle, const InputTensors& inputTensors, const OutputTensors& outputTensors, std::vector preImportedInputs, std::vector preImportedOutputs); /// Unloads a network from the Runtime. /// At the moment this only removes the network from the m_Impl->m_Network. /// This might need more work in the future to be AndroidNN compliant. /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork(). /// @return armnn::Status Status UnloadNetwork(NetworkId networkId); const IDeviceSpec& GetDeviceSpec() const { return m_DeviceSpec; } /// Gets the profiler corresponding to the given network id. /// @param networkId The id of the network for which to get the profile. /// @return A pointer to the requested profiler, or nullptr if not found. const std::shared_ptr GetProfiler(NetworkId networkId) const; /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have /// overlapped Execution by calling this function from different threads. std::unique_ptr CreateWorkingMemHandle(NetworkId networkId); /// Registers a callback function to debug layers performing custom computations on intermediate tensors. /// @param networkId The id of the network to register the callback. /// @param func callback function to pass to the debug layer. void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func); /// Creates a runtime for workload execution. RuntimeImpl(const IRuntime::CreationOptions& options); ~RuntimeImpl(); //NOTE: we won't need the profiling service reference but it is good to pass the service // in this way to facilitate other implementations down the road void ReportStructure(arm::pipe::IProfilingService& profilingService) override; void InitialiseProfilingService(arm::pipe::IProfilingService& profilingService) override; private: friend void RuntimeLoadedNetworksReserve(RuntimeImpl* runtime); // See RuntimeTests.cpp friend arm::pipe::IProfilingService& GetProfilingService(RuntimeImpl* runtime); // See RuntimeTests.cpp int GenerateNetworkId(); LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const; template void LoadedNetworkFuncSafe(NetworkId networkId, Func f) { #if !defined(ARMNN_DISABLE_THREADS) std::lock_guard lockGuard(m_Mutex); #endif auto iter = m_LoadedNetworks.find(networkId); if (iter != m_LoadedNetworks.end()) { f(iter->second.get()); } } /// Loads any available/compatible dynamic backend in the runtime. void LoadDynamicBackends(const std::string& overrideBackendPath); #if !defined(ARMNN_DISABLE_THREADS) mutable std::mutex m_Mutex; #endif /// Map of Loaded Networks with associated GUID as key LoadedNetworks m_LoadedNetworks; std::unordered_map m_BackendContexts; int m_NetworkIdCounter; DeviceSpec m_DeviceSpec; /// List of dynamic backends loaded in the runtime std::vector m_DynamicBackends; /// Profiling Service Instance std::unique_ptr m_ProfilingService; }; } // namespace armnn