diff options
Diffstat (limited to 'shim/sl/canonical/ArmnnDriverImpl.cpp')
-rw-r--r-- | shim/sl/canonical/ArmnnDriverImpl.cpp | 91 |
1 files changed, 50 insertions, 41 deletions
diff --git a/shim/sl/canonical/ArmnnDriverImpl.cpp b/shim/sl/canonical/ArmnnDriverImpl.cpp index 8706c382b0..0c98a16138 100644 --- a/shim/sl/canonical/ArmnnDriverImpl.cpp +++ b/shim/sl/canonical/ArmnnDriverImpl.cpp @@ -5,7 +5,6 @@ #include "ArmnnDriverImpl.hpp" #include "ArmnnPreparedModel.hpp" -#include "CacheDataHandler.hpp" #include "ModelToINetworkTransformer.hpp" #include "SystemPropertiesUtils.hpp" @@ -62,6 +61,16 @@ Capabilities GenerateCapabilities() /* whilePerformance */ defaultPerfInfo }; } +size_t Hash(std::vector<uint8_t>& cacheData) +{ + std::size_t hash = cacheData.size(); + for (auto& i : cacheData) + { + hash = ((hash << 5) - hash) + i; + } + return hash; +} + } // anonymous namespace using namespace android::nn; @@ -87,33 +96,6 @@ bool ArmnnDriverImpl::ValidateSharedHandle(const SharedHandle& sharedHandle) return valid; } -bool ArmnnDriverImpl::ValidateDataCacheHandle(const std::vector<SharedHandle>& dataCacheHandle, const size_t dataSize) -{ - bool valid = true; - // DataCacheHandle size should always be 1 for ArmNN model - if (dataCacheHandle.size() != 1) - { - return !valid; - } - - if (dataSize == 0) - { - return !valid; - } - - struct stat statBuffer; - if (fstat(*dataCacheHandle[0], &statBuffer) == 0) - { - unsigned long bufferSize = statBuffer.st_size; - if (bufferSize != dataSize) - { - return !valid; - } - } - - return ValidateSharedHandle(dataCacheHandle[0]); -} - GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModel( const armnn::IRuntimePtr& runtime, const armnn::IGpuAccTunedParametersPtr& clTunedParameters, @@ -274,8 +256,7 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModel( size_t hashValue = 0; if (dataCacheHandle.size() == 1 ) { - write(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size()); - hashValue = CacheDataHandlerInstance().Hash(dataCacheData); + hashValue = Hash(dataCacheData); } // Cache the model data @@ -296,16 +277,20 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModel( { std::vector<uint8_t> modelData(modelDataSize); pread(*modelCacheHandle[i], modelData.data(), modelData.size(), 0); - hashValue ^= CacheDataHandlerInstance().Hash(modelData); + hashValue ^= Hash(modelData); } } } } } } - if (hashValue != 0) + if (dataCacheHandle.size() == 1 && hashValue != 0) { - CacheDataHandlerInstance().Register(token, hashValue, dataCacheData.size()); + std::vector<uint8_t> theHashValue(sizeof(hashValue)); + ::memcpy(theHashValue.data(), &hashValue, sizeof(hashValue)); + + write(*dataCacheHandle[0], theHashValue.data(), theHashValue.size()); + pwrite(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), theHashValue.size()); } bool executeWithDummyInputs = (std::find(options.GetBackends().begin(), @@ -374,13 +359,30 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModelFromCache( } // Validate dataCacheHandle - auto dataSize = CacheDataHandlerInstance().GetCacheSize(token); - if (!ValidateDataCacheHandle(dataCacheHandle, dataSize)) + if (dataCacheHandle.size() != 1) { return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "ArmnnDriverImpl::prepareModelFromCache(): Not valid data cache handle!"; } + if (!ValidateSharedHandle(dataCacheHandle[0])) + { + return NN_ERROR(ErrorStatus::GENERAL_FAILURE) + << "ArmnnDriverImpl::prepareModelFromCache(): Not valid data cache handle!"; + } + + size_t cachedDataSize = 0; + struct stat dataStatBuffer; + if (fstat(*dataCacheHandle[0], &dataStatBuffer) == 0) + { + cachedDataSize = dataStatBuffer.st_size; + } + if (cachedDataSize == 0) + { + return NN_ERROR(ErrorStatus::GENERAL_FAILURE) + << "ArmnnDriverImpl::prepareModelFromCache(): Not valid cached data!"; + } + // Check if model files cached they match the expected value unsigned int numberOfCachedModelFiles = 0; for (auto& backend : options.GetBackends()) @@ -393,10 +395,14 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModelFromCache( << "ArmnnDriverImpl::prepareModelFromCache(): Model cache handle size does not match."; } + // Read the hashValue + std::vector<uint8_t> hashValue(sizeof(size_t)); + pread(*dataCacheHandle[0], hashValue.data(), hashValue.size(), 0); + // Read the model - std::vector<uint8_t> dataCacheData(dataSize); - pread(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), 0); - auto hashValue = CacheDataHandlerInstance().Hash(dataCacheData); + std::vector<uint8_t> dataCacheData(cachedDataSize - hashValue.size()); + pread(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), hashValue.size()); + auto calculatedHashValue = Hash(dataCacheData); int gpuAccCachedFd = -1; if (modelCacheHandle.size() > 0) @@ -423,7 +429,7 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModelFromCache( { std::vector<uint8_t> modelData(modelDataSize); pread(cachedFd, modelData.data(), modelData.size(), 0); - hashValue ^= CacheDataHandlerInstance().Hash(modelData); + calculatedHashValue ^= Hash(modelData); if (backend == armnn::Compute::GpuAcc) { @@ -436,7 +442,9 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModelFromCache( } } - if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size())) + std::vector<uint8_t> calculatedHashData(sizeof(calculatedHashValue)); + ::memcpy(calculatedHashData.data(), &calculatedHashValue, sizeof(calculatedHashValue)); + if (hashValue != calculatedHashData) { return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "ArmnnDriverImpl::prepareModelFromCache(): ValidateHash() failed!"; @@ -529,12 +537,13 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModelFromCache( return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << message.str(); } - return std::make_shared<const ArmnnPreparedModel>(netId, + auto preparedModel = std::make_shared<const ArmnnPreparedModel>(netId, runtime.get(), options.GetRequestInputsAndOutputsDumpDir(), options.IsGpuProfilingEnabled(), Priority::MEDIUM, true); + return std::move(preparedModel); } const Capabilities& ArmnnDriverImpl::GetCapabilities(const armnn::IRuntimePtr& runtime) |