diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2021-11-05 14:41:52 +0000 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2021-11-08 17:22:50 +0000 |
commit | ee6818be7815e10be4535645f0472ae5ad116309 (patch) | |
tree | 8880c9d0d8832f147afe05742716362de40a34c2 /1.3 | |
parent | e27d4e89a34b07628b9a3de89706ca2558e9ee8e (diff) | |
download | android-nn-driver-ee6818be7815e10be4535645f0472ae5ad116309.tar.gz |
IVGCVSW-5636 'Implement NNAPI caching functions'
* Fixed test failures.
!armnn:6617
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: I9989ece8999d67dd40dfcf69b73f4d80f71687a4
Diffstat (limited to '1.3')
-rw-r--r-- | 1.3/ArmnnDriverImpl.cpp | 56 |
1 files changed, 40 insertions, 16 deletions
diff --git a/1.3/ArmnnDriverImpl.cpp b/1.3/ArmnnDriverImpl.cpp index e1d65f92..c8b1d968 100644 --- a/1.3/ArmnnDriverImpl.cpp +++ b/1.3/ArmnnDriverImpl.cpp @@ -328,6 +328,14 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareArmnnModel_1_3( NotifyCallbackAndCheck(cb, V1_3::ErrorStatus::NONE, preparedModel.release()); return V1_3::ErrorStatus::NONE; } + + if (dataCacheHandle[0]->data[0] < 0) + { + ALOGW("ArmnnDriverImpl::prepareArmnnModel_1_3: Cannot cache the data, fd < 0"); + NotifyCallbackAndCheck(cb, V1_3::ErrorStatus::NONE, preparedModel.release()); + return V1_3::ErrorStatus::NONE; + } + int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE; if (dataCacheFileAccessMode != O_RDWR) { @@ -435,6 +443,13 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3( return V1_3::ErrorStatus::GENERAL_FAILURE; } + if (dataCacheHandle[0]->data[0] < 0) + { + ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3(): Cannot read from the cache data, fd < 0"); + cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr); + return V1_3::ErrorStatus::GENERAL_FAILURE; + } + int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE; if (dataCacheFileAccessMode != O_RDWR) { @@ -456,16 +471,12 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3( if (fstat(dataCacheHandle[0]->data[0], &statBuffer) == 0) { unsigned long bufferSize = statBuffer.st_size; - if (bufferSize <= 0) + if (bufferSize != dataSize) { ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3: Invalid data to deserialize!"); cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr); return V1_3::ErrorStatus::GENERAL_FAILURE; } - if (bufferSize > dataSize) - { - offset = bufferSize - dataSize; - } } } std::vector<uint8_t> dataCacheData(dataSize); @@ -504,17 +515,20 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3( if (cachedFd != -1 && fstat(cachedFd, &statBuffer) == 0) { long modelDataSize = statBuffer.st_size; - if (modelDataSize > 0) + if (modelDataSize <= 0) { - std::vector<uint8_t> modelData(modelDataSize); - pread(cachedFd, modelData.data(), modelData.size(), 0); - hashValue ^= CacheDataHandlerInstance().Hash(modelData); + ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3(): Wrong cached model size!"); + cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr); + return V1_3::ErrorStatus::NONE; + } + std::vector<uint8_t> modelData(modelDataSize); + pread(cachedFd, modelData.data(), modelData.size(), 0); + hashValue ^= CacheDataHandlerInstance().Hash(modelData); - // For GpuAcc numberOfCachedFiles is 1 - if (backend == armnn::Compute::GpuAcc) - { - gpuAccCachedFd = cachedFd; - } + // For GpuAcc numberOfCachedFiles is 1 + if (backend == armnn::Compute::GpuAcc) + { + gpuAccCachedFd = cachedFd; } } index += numberOfCacheFiles; @@ -522,7 +536,7 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3( } } - if (!CacheDataHandlerInstance().Validate(token, hashValue)) + if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size())) { ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3: ValidateHash() failed!"); cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr); @@ -530,7 +544,17 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3( } // Deserialize the network.. - auto network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData); + armnn::INetworkPtr network = armnn::INetworkPtr(nullptr, [](armnn::INetwork*){}); + try + { + network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData); + } + catch (std::exception&) + { + ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3: Exception caught from Deserializer!"); + cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr); + return V1_3::ErrorStatus::GENERAL_FAILURE; + } // Optimize the network armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr); |