aboutsummaryrefslogtreecommitdiff
path: root/1.3/ArmnnDriverImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to '1.3/ArmnnDriverImpl.cpp')
-rw-r--r--1.3/ArmnnDriverImpl.cpp56
1 files changed, 40 insertions, 16 deletions
diff --git a/1.3/ArmnnDriverImpl.cpp b/1.3/ArmnnDriverImpl.cpp
index e1d65f92..c8b1d968 100644
--- a/1.3/ArmnnDriverImpl.cpp
+++ b/1.3/ArmnnDriverImpl.cpp
@@ -328,6 +328,14 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareArmnnModel_1_3(
NotifyCallbackAndCheck(cb, V1_3::ErrorStatus::NONE, preparedModel.release());
return V1_3::ErrorStatus::NONE;
}
+
+ if (dataCacheHandle[0]->data[0] < 0)
+ {
+ ALOGW("ArmnnDriverImpl::prepareArmnnModel_1_3: Cannot cache the data, fd < 0");
+ NotifyCallbackAndCheck(cb, V1_3::ErrorStatus::NONE, preparedModel.release());
+ return V1_3::ErrorStatus::NONE;
+ }
+
int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
if (dataCacheFileAccessMode != O_RDWR)
{
@@ -435,6 +443,13 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3(
return V1_3::ErrorStatus::GENERAL_FAILURE;
}
+ if (dataCacheHandle[0]->data[0] < 0)
+ {
+ ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3(): Cannot read from the cache data, fd < 0");
+ cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
+ return V1_3::ErrorStatus::GENERAL_FAILURE;
+ }
+
int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
if (dataCacheFileAccessMode != O_RDWR)
{
@@ -456,16 +471,12 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3(
if (fstat(dataCacheHandle[0]->data[0], &statBuffer) == 0)
{
unsigned long bufferSize = statBuffer.st_size;
- if (bufferSize <= 0)
+ if (bufferSize != dataSize)
{
ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3: Invalid data to deserialize!");
cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
return V1_3::ErrorStatus::GENERAL_FAILURE;
}
- if (bufferSize > dataSize)
- {
- offset = bufferSize - dataSize;
- }
}
}
std::vector<uint8_t> dataCacheData(dataSize);
@@ -504,17 +515,20 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3(
if (cachedFd != -1 && fstat(cachedFd, &statBuffer) == 0)
{
long modelDataSize = statBuffer.st_size;
- if (modelDataSize > 0)
+ if (modelDataSize <= 0)
{
- std::vector<uint8_t> modelData(modelDataSize);
- pread(cachedFd, modelData.data(), modelData.size(), 0);
- hashValue ^= CacheDataHandlerInstance().Hash(modelData);
+ ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3(): Wrong cached model size!");
+ cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
+ return V1_3::ErrorStatus::NONE;
+ }
+ std::vector<uint8_t> modelData(modelDataSize);
+ pread(cachedFd, modelData.data(), modelData.size(), 0);
+ hashValue ^= CacheDataHandlerInstance().Hash(modelData);
- // For GpuAcc numberOfCachedFiles is 1
- if (backend == armnn::Compute::GpuAcc)
- {
- gpuAccCachedFd = cachedFd;
- }
+ // For GpuAcc numberOfCachedFiles is 1
+ if (backend == armnn::Compute::GpuAcc)
+ {
+ gpuAccCachedFd = cachedFd;
}
}
index += numberOfCacheFiles;
@@ -522,7 +536,7 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3(
}
}
- if (!CacheDataHandlerInstance().Validate(token, hashValue))
+ if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size()))
{
ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3: ValidateHash() failed!");
cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
@@ -530,7 +544,17 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3(
}
// Deserialize the network..
- auto network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
+ armnn::INetworkPtr network = armnn::INetworkPtr(nullptr, [](armnn::INetwork*){});
+ try
+ {
+ network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
+ }
+ catch (std::exception&)
+ {
+ ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3: Exception caught from Deserializer!");
+ cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
+ return V1_3::ErrorStatus::GENERAL_FAILURE;
+ }
// Optimize the network
armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);