diff options
Diffstat (limited to 'shim/sl/canonical')
-rw-r--r-- | shim/sl/canonical/ArmnnDriver.hpp | 6 | ||||
-rw-r--r-- | shim/sl/canonical/ArmnnDriverImpl.cpp | 14 | ||||
-rw-r--r-- | shim/sl/canonical/ArmnnDriverImpl.hpp | 6 | ||||
-rw-r--r-- | shim/sl/canonical/ArmnnPreparedModel.cpp | 43 |
4 files changed, 17 insertions, 52 deletions
diff --git a/shim/sl/canonical/ArmnnDriver.hpp b/shim/sl/canonical/ArmnnDriver.hpp index 877faa667e..c33c61a65b 100644 --- a/shim/sl/canonical/ArmnnDriver.hpp +++ b/shim/sl/canonical/ArmnnDriver.hpp @@ -39,12 +39,6 @@ public: ~ArmnnDriver() { VLOG(DRIVER) << "ArmnnDriver::~ArmnnDriver()"; - // Unload the networks - for (auto& netId : ArmnnDriverImpl::GetLoadedNetworks()) - { - m_Runtime->UnloadNetwork(netId); - } - ArmnnDriverImpl::ClearNetworks(); } public: diff --git a/shim/sl/canonical/ArmnnDriverImpl.cpp b/shim/sl/canonical/ArmnnDriverImpl.cpp index 3223d9e8bf..8706c382b0 100644 --- a/shim/sl/canonical/ArmnnDriverImpl.cpp +++ b/shim/sl/canonical/ArmnnDriverImpl.cpp @@ -114,11 +114,6 @@ bool ArmnnDriverImpl::ValidateDataCacheHandle(const std::vector<SharedHandle>& d return ValidateSharedHandle(dataCacheHandle[0]); } -std::vector<armnn::NetworkId>& ArmnnDriverImpl::GetLoadedNetworks() -{ - return m_NetworkIDs; -} - GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModel( const armnn::IRuntimePtr& runtime, const armnn::IGpuAccTunedParametersPtr& clTunedParameters, @@ -317,7 +312,6 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModel( options.GetBackends().end(), armnn::Compute::GpuAcc) != options.GetBackends().end()); - m_NetworkIDs.push_back(netId); auto preparedModel = std::make_shared<const ArmnnPreparedModel>(netId, runtime.get(), model, @@ -356,8 +350,6 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModel( return std::move(preparedModel); } -std::vector<armnn::NetworkId> ArmnnDriverImpl::m_NetworkIDs = {}; - GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModelFromCache( const armnn::IRuntimePtr& runtime, const armnn::IGpuAccTunedParametersPtr& clTunedParameters, @@ -537,7 +529,6 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModelFromCache( return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << message.str(); } - m_NetworkIDs.push_back(netId); return std::make_shared<const ArmnnPreparedModel>(netId, runtime.get(), options.GetRequestInputsAndOutputsDumpDir(), @@ -553,9 +544,4 @@ const Capabilities& ArmnnDriverImpl::GetCapabilities(const armnn::IRuntimePtr& r return theCapabilities; } -void ArmnnDriverImpl::ClearNetworks() -{ - m_NetworkIDs.clear(); -} - } // namespace armnn_driver diff --git a/shim/sl/canonical/ArmnnDriverImpl.hpp b/shim/sl/canonical/ArmnnDriverImpl.hpp index 836bf469cc..6af0ab285d 100644 --- a/shim/sl/canonical/ArmnnDriverImpl.hpp +++ b/shim/sl/canonical/ArmnnDriverImpl.hpp @@ -45,15 +45,9 @@ public: static const Capabilities& GetCapabilities(const armnn::IRuntimePtr& runtime); - static std::vector<armnn::NetworkId>& GetLoadedNetworks(); - - static void ClearNetworks(); - private: static bool ValidateSharedHandle(const SharedHandle& sharedHandle); static bool ValidateDataCacheHandle(const std::vector<SharedHandle>& dataCacheHandle, const size_t dataSize); - - static std::vector<armnn::NetworkId> m_NetworkIDs; }; } // namespace armnn_driver
\ No newline at end of file diff --git a/shim/sl/canonical/ArmnnPreparedModel.cpp b/shim/sl/canonical/ArmnnPreparedModel.cpp index c0ce3e41a1..54a019004c 100644 --- a/shim/sl/canonical/ArmnnPreparedModel.cpp +++ b/shim/sl/canonical/ArmnnPreparedModel.cpp @@ -93,21 +93,21 @@ bool IsPointerTypeMemory(const Request& request) { for (auto& input : request.inputs) { - if (input.lifetime == Request::Argument::LifeTime::POINTER) + if (input.lifetime != Request::Argument::LifeTime::POINTER) { - return true; + return false; } } for (auto& output: request.outputs) { - if (output.lifetime == Request::Argument::LifeTime::POINTER) + if (output.lifetime != Request::Argument::LifeTime::POINTER) { - return true; + return false; } } - return false; + return true; } } // anonymous namespace @@ -318,7 +318,8 @@ ExecutionResult<std::pair<std::vector<OutputShape>, Timing>> ArmnnPreparedModel: } VLOG(DRIVER) << "ArmnnPreparedModel::execute(): " << GetModelSummary(m_Model).c_str(); } - if (hasDeadlinePassed(deadline)) { + if (hasDeadlinePassed(deadline)) + { return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT); } @@ -381,7 +382,8 @@ ErrorStatus ArmnnPreparedModel::ExecuteGraph( VLOG(DRIVER) << "ArmnnPreparedModel::ExecuteGraph(...)"; DumpTensorsIfRequired("Input", inputTensors); - + std::vector<armnn::ImportedInputId> importedInputIds; + std::vector<armnn::ImportedOutputId> importedOutputIds; try { if (ctx.measureTimings == MeasureTiming::YES) @@ -390,24 +392,13 @@ ErrorStatus ArmnnPreparedModel::ExecuteGraph( } armnn::Status status; VLOG(DRIVER) << "ArmnnPreparedModel::ExecuteGraph m_AsyncModelExecutionEnabled false"; - - if (pointerMemory) - { - std::vector<armnn::ImportedInputId> importedInputIds; - importedInputIds = m_Runtime->ImportInputs(m_NetworkId, inputTensors, armnn::MemorySource::Malloc); - - std::vector<armnn::ImportedOutputId> importedOutputIds; - importedOutputIds = m_Runtime->ImportOutputs(m_NetworkId, outputTensors, armnn::MemorySource::Malloc); - status = m_Runtime->EnqueueWorkload(m_NetworkId, - inputTensors, - outputTensors, - importedInputIds, - importedOutputIds); - } - else - { - status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors); - } + importedInputIds = m_Runtime->ImportInputs(m_NetworkId, inputTensors, armnn::MemorySource::Malloc); + importedOutputIds = m_Runtime->ImportOutputs(m_NetworkId, outputTensors, armnn::MemorySource::Malloc); + status = m_Runtime->EnqueueWorkload(m_NetworkId, + inputTensors, + outputTensors, + importedInputIds, + importedOutputIds); if (ctx.measureTimings == MeasureTiming::YES) { @@ -430,7 +421,7 @@ ErrorStatus ArmnnPreparedModel::ExecuteGraph( return ErrorStatus::GENERAL_FAILURE; } - if (!pointerMemory) + if (!pointerMemory && (!importedInputIds.empty() || !importedOutputIds.empty())) { CommitPools(*pMemPools); } |