aboutsummaryrefslogtreecommitdiff
path: root/1.2
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2021-11-05 14:41:52 +0000
committerSadik Armagan <sadik.armagan@arm.com>2021-11-08 17:22:50 +0000
commitee6818be7815e10be4535645f0472ae5ad116309 (patch)
tree8880c9d0d8832f147afe05742716362de40a34c2 /1.2
parente27d4e89a34b07628b9a3de89706ca2558e9ee8e (diff)
downloadandroid-nn-driver-ee6818be7815e10be4535645f0472ae5ad116309.tar.gz
IVGCVSW-5636 'Implement NNAPI caching functions'
* Fixed test failures. !armnn:6617 Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Signed-off-by: Kevin May <kevin.may@arm.com> Change-Id: I9989ece8999d67dd40dfcf69b73f4d80f71687a4
Diffstat (limited to '1.2')
-rw-r--r--1.2/ArmnnDriverImpl.cpp56
1 files changed, 40 insertions, 16 deletions
diff --git a/1.2/ArmnnDriverImpl.cpp b/1.2/ArmnnDriverImpl.cpp
index b3bc5cd1..3274a8ab 100644
--- a/1.2/ArmnnDriverImpl.cpp
+++ b/1.2/ArmnnDriverImpl.cpp
@@ -315,6 +315,14 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareArmnnModel_1_2(
NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
return V1_0::ErrorStatus::NONE;
}
+
+ if (dataCacheHandle[0]->data[0] < 0)
+ {
+ ALOGW("ArmnnDriverImpl::prepareArmnnModel_1_3: Cannot cache the data, fd < 0");
+ NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
+ return V1_0::ErrorStatus::NONE;
+ }
+
int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
if (dataCacheFileAccessMode != O_RDWR)
{
@@ -420,6 +428,13 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
return V1_0::ErrorStatus::GENERAL_FAILURE;
}
+ if (dataCacheHandle[0]->data[0] < 0)
+ {
+ ALOGW("ArmnnDriverImpl::prepareModelFromCache: Cannot read from the cache data, fd < 0");
+ FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "No data cache!", cb);
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
+ }
+
int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
if (dataCacheFileAccessMode != O_RDWR)
{
@@ -441,16 +456,12 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
if (fstat(dataCacheHandle[0]->data[0], &statBuffer) == 0)
{
unsigned long bufferSize = statBuffer.st_size;
- if (bufferSize <= 0)
+ if (bufferSize != dataSize)
{
ALOGW("ArmnnDriverImpl::prepareModelFromCache: Invalid data to deserialize!");
FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Invalid data to deserialize!", cb);
return V1_0::ErrorStatus::GENERAL_FAILURE;
}
- if (bufferSize > dataSize)
- {
- offset = bufferSize - dataSize;
- }
}
}
std::vector<uint8_t> dataCacheData(dataSize);
@@ -489,17 +500,19 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
if (cachedFd != -1 && fstat(cachedFd, &statBuffer) == 0)
{
long modelDataSize = statBuffer.st_size;
- if (modelDataSize > 0)
+ if (modelDataSize <= 0)
{
- std::vector<uint8_t> modelData(modelDataSize);
- pread(cachedFd, modelData.data(), modelData.size(), 0);
- hashValue ^= CacheDataHandlerInstance().Hash(modelData);
+ FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Wrong cached model size!", cb);
+ return V1_0::ErrorStatus::NONE;
+ }
+ std::vector<uint8_t> modelData(modelDataSize);
+ pread(cachedFd, modelData.data(), modelData.size(), 0);
+ hashValue ^= CacheDataHandlerInstance().Hash(modelData);
- // For GpuAcc numberOfCachedFiles is 1
- if (backend == armnn::Compute::GpuAcc)
- {
- gpuAccCachedFd = cachedFd;
- }
+ // For GpuAcc numberOfCachedFiles is 1
+ if (backend == armnn::Compute::GpuAcc)
+ {
+ gpuAccCachedFd = cachedFd;
}
}
index += numberOfCacheFiles;
@@ -507,7 +520,7 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
}
}
- if (!CacheDataHandlerInstance().Validate(token, hashValue))
+ if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size()))
{
ALOGW("ArmnnDriverImpl::prepareModelFromCache: ValidateHash() failed!");
FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "ValidateHash Failed!", cb);
@@ -515,7 +528,18 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
}
// Deserialize the network..
- auto network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
+ armnn::INetworkPtr network = armnn::INetworkPtr(nullptr, [](armnn::INetwork*){});
+ try
+ {
+ network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
+ }
+ catch (std::exception& e)
+ {
+ std::stringstream message;
+ message << "Exception (" << e.what() << ") caught from Deserializer.";
+ FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
+ }
// Optimize the network
armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);