aboutsummaryrefslogtreecommitdiff
path: root/1.2
diff options
context:
space:
mode:
Diffstat (limited to '1.2')
-rw-r--r--1.2/ArmnnDriverImpl.cpp56
1 files changed, 40 insertions, 16 deletions
diff --git a/1.2/ArmnnDriverImpl.cpp b/1.2/ArmnnDriverImpl.cpp
index b3bc5cd1..3274a8ab 100644
--- a/1.2/ArmnnDriverImpl.cpp
+++ b/1.2/ArmnnDriverImpl.cpp
@@ -315,6 +315,14 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareArmnnModel_1_2(
NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
return V1_0::ErrorStatus::NONE;
}
+
+ if (dataCacheHandle[0]->data[0] < 0)
+ {
+ ALOGW("ArmnnDriverImpl::prepareArmnnModel_1_3: Cannot cache the data, fd < 0");
+ NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
+ return V1_0::ErrorStatus::NONE;
+ }
+
int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
if (dataCacheFileAccessMode != O_RDWR)
{
@@ -420,6 +428,13 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
return V1_0::ErrorStatus::GENERAL_FAILURE;
}
+ if (dataCacheHandle[0]->data[0] < 0)
+ {
+ ALOGW("ArmnnDriverImpl::prepareModelFromCache: Cannot read from the cache data, fd < 0");
+ FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "No data cache!", cb);
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
+ }
+
int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
if (dataCacheFileAccessMode != O_RDWR)
{
@@ -441,16 +456,12 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
if (fstat(dataCacheHandle[0]->data[0], &statBuffer) == 0)
{
unsigned long bufferSize = statBuffer.st_size;
- if (bufferSize <= 0)
+ if (bufferSize != dataSize)
{
ALOGW("ArmnnDriverImpl::prepareModelFromCache: Invalid data to deserialize!");
FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Invalid data to deserialize!", cb);
return V1_0::ErrorStatus::GENERAL_FAILURE;
}
- if (bufferSize > dataSize)
- {
- offset = bufferSize - dataSize;
- }
}
}
std::vector<uint8_t> dataCacheData(dataSize);
@@ -489,17 +500,19 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
if (cachedFd != -1 && fstat(cachedFd, &statBuffer) == 0)
{
long modelDataSize = statBuffer.st_size;
- if (modelDataSize > 0)
+ if (modelDataSize <= 0)
{
- std::vector<uint8_t> modelData(modelDataSize);
- pread(cachedFd, modelData.data(), modelData.size(), 0);
- hashValue ^= CacheDataHandlerInstance().Hash(modelData);
+ FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Wrong cached model size!", cb);
+ return V1_0::ErrorStatus::NONE;
+ }
+ std::vector<uint8_t> modelData(modelDataSize);
+ pread(cachedFd, modelData.data(), modelData.size(), 0);
+ hashValue ^= CacheDataHandlerInstance().Hash(modelData);
- // For GpuAcc numberOfCachedFiles is 1
- if (backend == armnn::Compute::GpuAcc)
- {
- gpuAccCachedFd = cachedFd;
- }
+ // For GpuAcc numberOfCachedFiles is 1
+ if (backend == armnn::Compute::GpuAcc)
+ {
+ gpuAccCachedFd = cachedFd;
}
}
index += numberOfCacheFiles;
@@ -507,7 +520,7 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
}
}
- if (!CacheDataHandlerInstance().Validate(token, hashValue))
+ if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size()))
{
ALOGW("ArmnnDriverImpl::prepareModelFromCache: ValidateHash() failed!");
FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "ValidateHash Failed!", cb);
@@ -515,7 +528,18 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
}
// Deserialize the network..
- auto network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
+ armnn::INetworkPtr network = armnn::INetworkPtr(nullptr, [](armnn::INetwork*){});
+ try
+ {
+ network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
+ }
+ catch (std::exception& e)
+ {
+ std::stringstream message;
+ message << "Exception (" << e.what() << ") caught from Deserializer.";
+ FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
+ }
// Optimize the network
armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);