From ee6818be7815e10be4535645f0472ae5ad116309 Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Fri, 5 Nov 2021 14:41:52 +0000 Subject: IVGCVSW-5636 'Implement NNAPI caching functions' * Fixed test failures. !armnn:6617 Signed-off-by: Sadik Armagan Signed-off-by: Kevin May Change-Id: I9989ece8999d67dd40dfcf69b73f4d80f71687a4 --- 1.2/ArmnnDriverImpl.cpp | 56 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 16 deletions(-) (limited to '1.2') diff --git a/1.2/ArmnnDriverImpl.cpp b/1.2/ArmnnDriverImpl.cpp index b3bc5cd1..3274a8ab 100644 --- a/1.2/ArmnnDriverImpl.cpp +++ b/1.2/ArmnnDriverImpl.cpp @@ -315,6 +315,14 @@ Return ArmnnDriverImpl::prepareArmnnModel_1_2( NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release()); return V1_0::ErrorStatus::NONE; } + + if (dataCacheHandle[0]->data[0] < 0) + { + ALOGW("ArmnnDriverImpl::prepareArmnnModel_1_3: Cannot cache the data, fd < 0"); + NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release()); + return V1_0::ErrorStatus::NONE; + } + int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE; if (dataCacheFileAccessMode != O_RDWR) { @@ -420,6 +428,13 @@ Return ArmnnDriverImpl::prepareModelFromCache( return V1_0::ErrorStatus::GENERAL_FAILURE; } + if (dataCacheHandle[0]->data[0] < 0) + { + ALOGW("ArmnnDriverImpl::prepareModelFromCache: Cannot read from the cache data, fd < 0"); + FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "No data cache!", cb); + return V1_0::ErrorStatus::GENERAL_FAILURE; + } + int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE; if (dataCacheFileAccessMode != O_RDWR) { @@ -441,16 +456,12 @@ Return ArmnnDriverImpl::prepareModelFromCache( if (fstat(dataCacheHandle[0]->data[0], &statBuffer) == 0) { unsigned long bufferSize = statBuffer.st_size; - if (bufferSize <= 0) + if (bufferSize != dataSize) { ALOGW("ArmnnDriverImpl::prepareModelFromCache: Invalid data to deserialize!"); FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Invalid data to deserialize!", cb); return V1_0::ErrorStatus::GENERAL_FAILURE; } - if (bufferSize > dataSize) - { - offset = bufferSize - dataSize; - } } } std::vector dataCacheData(dataSize); @@ -489,17 +500,19 @@ Return ArmnnDriverImpl::prepareModelFromCache( if (cachedFd != -1 && fstat(cachedFd, &statBuffer) == 0) { long modelDataSize = statBuffer.st_size; - if (modelDataSize > 0) + if (modelDataSize <= 0) { - std::vector modelData(modelDataSize); - pread(cachedFd, modelData.data(), modelData.size(), 0); - hashValue ^= CacheDataHandlerInstance().Hash(modelData); + FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Wrong cached model size!", cb); + return V1_0::ErrorStatus::NONE; + } + std::vector modelData(modelDataSize); + pread(cachedFd, modelData.data(), modelData.size(), 0); + hashValue ^= CacheDataHandlerInstance().Hash(modelData); - // For GpuAcc numberOfCachedFiles is 1 - if (backend == armnn::Compute::GpuAcc) - { - gpuAccCachedFd = cachedFd; - } + // For GpuAcc numberOfCachedFiles is 1 + if (backend == armnn::Compute::GpuAcc) + { + gpuAccCachedFd = cachedFd; } } index += numberOfCacheFiles; @@ -507,7 +520,7 @@ Return ArmnnDriverImpl::prepareModelFromCache( } } - if (!CacheDataHandlerInstance().Validate(token, hashValue)) + if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size())) { ALOGW("ArmnnDriverImpl::prepareModelFromCache: ValidateHash() failed!"); FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "ValidateHash Failed!", cb); @@ -515,7 +528,18 @@ Return ArmnnDriverImpl::prepareModelFromCache( } // Deserialize the network.. - auto network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData); + armnn::INetworkPtr network = armnn::INetworkPtr(nullptr, [](armnn::INetwork*){}); + try + { + network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData); + } + catch (std::exception& e) + { + std::stringstream message; + message << "Exception (" << e.what() << ") caught from Deserializer."; + FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb); + return V1_0::ErrorStatus::GENERAL_FAILURE; + } // Optimize the network armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr); -- cgit v1.2.1