From 2f2e0be2f5389c24ac74099a3300aa983bc4adcb Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Tue, 2 Aug 2022 09:17:23 +0100 Subject: IVGCVSW-7063 'Support Library NNAPI Caching' * Fixed caching issue. Signed-off-by: Sadik Armagan Change-Id: Ic7b3e0bd4438b2fd1b3dbfa86b6c89d625bbf9dd --- shim/sl/CMakeLists.txt | 2 - shim/sl/canonical/ArmnnDriverImpl.cpp | 91 +++++++++++++++++++--------------- shim/sl/canonical/CacheDataHandler.cpp | 69 -------------------------- shim/sl/canonical/CacheDataHandler.hpp | 64 ------------------------ src/armnnDeserializer/Deserializer.cpp | 2 + 5 files changed, 52 insertions(+), 176 deletions(-) delete mode 100644 shim/sl/canonical/CacheDataHandler.cpp delete mode 100644 shim/sl/canonical/CacheDataHandler.hpp diff --git a/shim/sl/CMakeLists.txt b/shim/sl/CMakeLists.txt index 81c97f93df..0ba6390b41 100644 --- a/shim/sl/CMakeLists.txt +++ b/shim/sl/CMakeLists.txt @@ -474,8 +474,6 @@ list(APPEND armnn_support_library_sources canonical/ArmnnDriver.hpp canonical/ArmnnDriverImpl.cpp canonical/ArmnnDriverImpl.hpp - canonical/CacheDataHandler.cpp - canonical/CacheDataHandler.hpp canonical/CanonicalUtils.cpp canonical/CanonicalUtils.hpp canonical/ConversionUtils.cpp diff --git a/shim/sl/canonical/ArmnnDriverImpl.cpp b/shim/sl/canonical/ArmnnDriverImpl.cpp index 8706c382b0..0c98a16138 100644 --- a/shim/sl/canonical/ArmnnDriverImpl.cpp +++ b/shim/sl/canonical/ArmnnDriverImpl.cpp @@ -5,7 +5,6 @@ #include "ArmnnDriverImpl.hpp" #include "ArmnnPreparedModel.hpp" -#include "CacheDataHandler.hpp" #include "ModelToINetworkTransformer.hpp" #include "SystemPropertiesUtils.hpp" @@ -62,6 +61,16 @@ Capabilities GenerateCapabilities() /* whilePerformance */ defaultPerfInfo }; } +size_t Hash(std::vector& cacheData) +{ + std::size_t hash = cacheData.size(); + for (auto& i : cacheData) + { + hash = ((hash << 5) - hash) + i; + } + return hash; +} + } // anonymous namespace using namespace android::nn; @@ -87,33 +96,6 @@ bool ArmnnDriverImpl::ValidateSharedHandle(const SharedHandle& sharedHandle) return valid; } -bool ArmnnDriverImpl::ValidateDataCacheHandle(const std::vector& dataCacheHandle, const size_t dataSize) -{ - bool valid = true; - // DataCacheHandle size should always be 1 for ArmNN model - if (dataCacheHandle.size() != 1) - { - return !valid; - } - - if (dataSize == 0) - { - return !valid; - } - - struct stat statBuffer; - if (fstat(*dataCacheHandle[0], &statBuffer) == 0) - { - unsigned long bufferSize = statBuffer.st_size; - if (bufferSize != dataSize) - { - return !valid; - } - } - - return ValidateSharedHandle(dataCacheHandle[0]); -} - GeneralResult ArmnnDriverImpl::PrepareArmnnModel( const armnn::IRuntimePtr& runtime, const armnn::IGpuAccTunedParametersPtr& clTunedParameters, @@ -274,8 +256,7 @@ GeneralResult ArmnnDriverImpl::PrepareArmnnModel( size_t hashValue = 0; if (dataCacheHandle.size() == 1 ) { - write(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size()); - hashValue = CacheDataHandlerInstance().Hash(dataCacheData); + hashValue = Hash(dataCacheData); } // Cache the model data @@ -296,16 +277,20 @@ GeneralResult ArmnnDriverImpl::PrepareArmnnModel( { std::vector modelData(modelDataSize); pread(*modelCacheHandle[i], modelData.data(), modelData.size(), 0); - hashValue ^= CacheDataHandlerInstance().Hash(modelData); + hashValue ^= Hash(modelData); } } } } } } - if (hashValue != 0) + if (dataCacheHandle.size() == 1 && hashValue != 0) { - CacheDataHandlerInstance().Register(token, hashValue, dataCacheData.size()); + std::vector theHashValue(sizeof(hashValue)); + ::memcpy(theHashValue.data(), &hashValue, sizeof(hashValue)); + + write(*dataCacheHandle[0], theHashValue.data(), theHashValue.size()); + pwrite(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), theHashValue.size()); } bool executeWithDummyInputs = (std::find(options.GetBackends().begin(), @@ -374,13 +359,30 @@ GeneralResult ArmnnDriverImpl::PrepareArmnnModelFromCache( } // Validate dataCacheHandle - auto dataSize = CacheDataHandlerInstance().GetCacheSize(token); - if (!ValidateDataCacheHandle(dataCacheHandle, dataSize)) + if (dataCacheHandle.size() != 1) { return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "ArmnnDriverImpl::prepareModelFromCache(): Not valid data cache handle!"; } + if (!ValidateSharedHandle(dataCacheHandle[0])) + { + return NN_ERROR(ErrorStatus::GENERAL_FAILURE) + << "ArmnnDriverImpl::prepareModelFromCache(): Not valid data cache handle!"; + } + + size_t cachedDataSize = 0; + struct stat dataStatBuffer; + if (fstat(*dataCacheHandle[0], &dataStatBuffer) == 0) + { + cachedDataSize = dataStatBuffer.st_size; + } + if (cachedDataSize == 0) + { + return NN_ERROR(ErrorStatus::GENERAL_FAILURE) + << "ArmnnDriverImpl::prepareModelFromCache(): Not valid cached data!"; + } + // Check if model files cached they match the expected value unsigned int numberOfCachedModelFiles = 0; for (auto& backend : options.GetBackends()) @@ -393,10 +395,14 @@ GeneralResult ArmnnDriverImpl::PrepareArmnnModelFromCache( << "ArmnnDriverImpl::prepareModelFromCache(): Model cache handle size does not match."; } + // Read the hashValue + std::vector hashValue(sizeof(size_t)); + pread(*dataCacheHandle[0], hashValue.data(), hashValue.size(), 0); + // Read the model - std::vector dataCacheData(dataSize); - pread(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), 0); - auto hashValue = CacheDataHandlerInstance().Hash(dataCacheData); + std::vector dataCacheData(cachedDataSize - hashValue.size()); + pread(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), hashValue.size()); + auto calculatedHashValue = Hash(dataCacheData); int gpuAccCachedFd = -1; if (modelCacheHandle.size() > 0) @@ -423,7 +429,7 @@ GeneralResult ArmnnDriverImpl::PrepareArmnnModelFromCache( { std::vector modelData(modelDataSize); pread(cachedFd, modelData.data(), modelData.size(), 0); - hashValue ^= CacheDataHandlerInstance().Hash(modelData); + calculatedHashValue ^= Hash(modelData); if (backend == armnn::Compute::GpuAcc) { @@ -436,7 +442,9 @@ GeneralResult ArmnnDriverImpl::PrepareArmnnModelFromCache( } } - if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size())) + std::vector calculatedHashData(sizeof(calculatedHashValue)); + ::memcpy(calculatedHashData.data(), &calculatedHashValue, sizeof(calculatedHashValue)); + if (hashValue != calculatedHashData) { return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "ArmnnDriverImpl::prepareModelFromCache(): ValidateHash() failed!"; @@ -529,12 +537,13 @@ GeneralResult ArmnnDriverImpl::PrepareArmnnModelFromCache( return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << message.str(); } - return std::make_shared(netId, + auto preparedModel = std::make_shared(netId, runtime.get(), options.GetRequestInputsAndOutputsDumpDir(), options.IsGpuProfilingEnabled(), Priority::MEDIUM, true); + return std::move(preparedModel); } const Capabilities& ArmnnDriverImpl::GetCapabilities(const armnn::IRuntimePtr& runtime) diff --git a/shim/sl/canonical/CacheDataHandler.cpp b/shim/sl/canonical/CacheDataHandler.cpp deleted file mode 100644 index 930a8e4264..0000000000 --- a/shim/sl/canonical/CacheDataHandler.cpp +++ /dev/null @@ -1,69 +0,0 @@ -// -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "CacheDataHandler.hpp" - -#include - -namespace armnn_driver -{ - -CacheDataHandler& CacheDataHandlerInstance() -{ - static CacheDataHandler instance; - return instance; -} - -void CacheDataHandler::Register(const android::nn::CacheToken token, const size_t hashValue, const size_t cacheSize) -{ - if (!m_CacheDataMap.empty() - && m_CacheDataMap.find(hashValue) != m_CacheDataMap.end() - && m_CacheDataMap.at(hashValue).GetToken() == token - && m_CacheDataMap.at(hashValue).GetCacheSize() == cacheSize) - { - return; - } - CacheHandle cacheHandle(token, cacheSize); - m_CacheDataMap.insert({hashValue, cacheHandle}); -} - -bool CacheDataHandler::Validate(const android::nn::CacheToken token, - const size_t hashValue, - const size_t cacheSize) const -{ - return (!m_CacheDataMap.empty() - && m_CacheDataMap.find(hashValue) != m_CacheDataMap.end() - && m_CacheDataMap.at(hashValue).GetToken() == token - && m_CacheDataMap.at(hashValue).GetCacheSize() == cacheSize); -} - -size_t CacheDataHandler::Hash(std::vector& cacheData) -{ - std::size_t hash = cacheData.size(); - for (auto& i : cacheData) - { - hash = ((hash << 5) - hash) + i; - } - return hash; -} - -size_t CacheDataHandler::GetCacheSize(android::nn::CacheToken token) -{ - for (auto i = m_CacheDataMap.begin(); i != m_CacheDataMap.end(); ++i) - { - if (i->second.GetToken() == token) - { - return i->second.GetCacheSize(); - } - } - return 0; -} - -void CacheDataHandler::Clear() -{ - m_CacheDataMap.clear(); -} - -} // armnn_driver diff --git a/shim/sl/canonical/CacheDataHandler.hpp b/shim/sl/canonical/CacheDataHandler.hpp deleted file mode 100644 index 95464a9809..0000000000 --- a/shim/sl/canonical/CacheDataHandler.hpp +++ /dev/null @@ -1,64 +0,0 @@ -// -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include -#include - -#include - -namespace armnn_driver -{ - -class CacheHandle -{ -public: - CacheHandle(const android::nn::CacheToken token, const size_t cacheSize) - : m_CacheToken(token), m_CacheSize(cacheSize) {} - - ~CacheHandle() {}; - - android::nn::CacheToken GetToken() const - { - return m_CacheToken; - } - - size_t GetCacheSize() const - { - return m_CacheSize; - } - -private: - const android::nn::CacheToken m_CacheToken; - const size_t m_CacheSize; -}; - -class CacheDataHandler -{ -public: - CacheDataHandler() {} - ~CacheDataHandler() {} - - void Register(const android::nn::CacheToken token, const size_t hashValue, const size_t cacheSize); - - bool Validate(const android::nn::CacheToken token, const size_t hashValue, const size_t cacheSize) const; - - size_t Hash(std::vector& cacheData); - - size_t GetCacheSize(android::nn::CacheToken token); - - void Clear(); - -private: - CacheDataHandler(const CacheDataHandler&) = delete; - CacheDataHandler& operator=(const CacheDataHandler&) = delete; - - std::unordered_map m_CacheDataMap; -}; - -CacheDataHandler& CacheDataHandlerInstance(); - -} // armnn_driver diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 90558bbd53..a405cb92a5 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -1396,6 +1396,7 @@ void IDeserializer::DeserializerImpl::ParseConstant(GraphPtr graph, unsigned int weightsShape[0], weightsShape[1], weightsShape[2]*weightsShape[3]}); + weightsInfo.SetConstant(true); armnn::ConstTensor weightsPermuted(weightsInfo, permuteBuffer.get()); @@ -1412,6 +1413,7 @@ void IDeserializer::DeserializerImpl::ParseConstant(GraphPtr graph, unsigned int layer = m_Network->AddConstantLayer(input, layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + outputTensorInfo.SetConstant(true); layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); } -- cgit v1.2.1