aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2021-11-05 14:41:52 +0000
committerSadik Armagan <sadik.armagan@arm.com>2021-11-08 17:22:50 +0000
commitee6818be7815e10be4535645f0472ae5ad116309 (patch)
tree8880c9d0d8832f147afe05742716362de40a34c2
parente27d4e89a34b07628b9a3de89706ca2558e9ee8e (diff)
downloadandroid-nn-driver-ee6818be7815e10be4535645f0472ae5ad116309.tar.gz
IVGCVSW-5636 'Implement NNAPI caching functions'
* Fixed test failures. !armnn:6617 Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Signed-off-by: Kevin May <kevin.may@arm.com> Change-Id: I9989ece8999d67dd40dfcf69b73f4d80f71687a4
-rw-r--r--1.2/ArmnnDriverImpl.cpp56
-rw-r--r--1.3/ArmnnDriverImpl.cpp56
-rw-r--r--CacheDataHandler.cpp13
-rw-r--r--CacheDataHandler.hpp2
4 files changed, 89 insertions, 38 deletions
diff --git a/1.2/ArmnnDriverImpl.cpp b/1.2/ArmnnDriverImpl.cpp
index b3bc5cd1..3274a8ab 100644
--- a/1.2/ArmnnDriverImpl.cpp
+++ b/1.2/ArmnnDriverImpl.cpp
@@ -315,6 +315,14 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareArmnnModel_1_2(
NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
return V1_0::ErrorStatus::NONE;
}
+
+ if (dataCacheHandle[0]->data[0] < 0)
+ {
+ ALOGW("ArmnnDriverImpl::prepareArmnnModel_1_3: Cannot cache the data, fd < 0");
+ NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
+ return V1_0::ErrorStatus::NONE;
+ }
+
int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
if (dataCacheFileAccessMode != O_RDWR)
{
@@ -420,6 +428,13 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
return V1_0::ErrorStatus::GENERAL_FAILURE;
}
+ if (dataCacheHandle[0]->data[0] < 0)
+ {
+ ALOGW("ArmnnDriverImpl::prepareModelFromCache: Cannot read from the cache data, fd < 0");
+ FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "No data cache!", cb);
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
+ }
+
int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
if (dataCacheFileAccessMode != O_RDWR)
{
@@ -441,16 +456,12 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
if (fstat(dataCacheHandle[0]->data[0], &statBuffer) == 0)
{
unsigned long bufferSize = statBuffer.st_size;
- if (bufferSize <= 0)
+ if (bufferSize != dataSize)
{
ALOGW("ArmnnDriverImpl::prepareModelFromCache: Invalid data to deserialize!");
FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Invalid data to deserialize!", cb);
return V1_0::ErrorStatus::GENERAL_FAILURE;
}
- if (bufferSize > dataSize)
- {
- offset = bufferSize - dataSize;
- }
}
}
std::vector<uint8_t> dataCacheData(dataSize);
@@ -489,17 +500,19 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
if (cachedFd != -1 && fstat(cachedFd, &statBuffer) == 0)
{
long modelDataSize = statBuffer.st_size;
- if (modelDataSize > 0)
+ if (modelDataSize <= 0)
{
- std::vector<uint8_t> modelData(modelDataSize);
- pread(cachedFd, modelData.data(), modelData.size(), 0);
- hashValue ^= CacheDataHandlerInstance().Hash(modelData);
+ FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Wrong cached model size!", cb);
+ return V1_0::ErrorStatus::NONE;
+ }
+ std::vector<uint8_t> modelData(modelDataSize);
+ pread(cachedFd, modelData.data(), modelData.size(), 0);
+ hashValue ^= CacheDataHandlerInstance().Hash(modelData);
- // For GpuAcc numberOfCachedFiles is 1
- if (backend == armnn::Compute::GpuAcc)
- {
- gpuAccCachedFd = cachedFd;
- }
+ // For GpuAcc numberOfCachedFiles is 1
+ if (backend == armnn::Compute::GpuAcc)
+ {
+ gpuAccCachedFd = cachedFd;
}
}
index += numberOfCacheFiles;
@@ -507,7 +520,7 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
}
}
- if (!CacheDataHandlerInstance().Validate(token, hashValue))
+ if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size()))
{
ALOGW("ArmnnDriverImpl::prepareModelFromCache: ValidateHash() failed!");
FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "ValidateHash Failed!", cb);
@@ -515,7 +528,18 @@ Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
}
// Deserialize the network..
- auto network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
+ armnn::INetworkPtr network = armnn::INetworkPtr(nullptr, [](armnn::INetwork*){});
+ try
+ {
+ network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
+ }
+ catch (std::exception& e)
+ {
+ std::stringstream message;
+ message << "Exception (" << e.what() << ") caught from Deserializer.";
+ FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
+ }
// Optimize the network
armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
diff --git a/1.3/ArmnnDriverImpl.cpp b/1.3/ArmnnDriverImpl.cpp
index e1d65f92..c8b1d968 100644
--- a/1.3/ArmnnDriverImpl.cpp
+++ b/1.3/ArmnnDriverImpl.cpp
@@ -328,6 +328,14 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareArmnnModel_1_3(
NotifyCallbackAndCheck(cb, V1_3::ErrorStatus::NONE, preparedModel.release());
return V1_3::ErrorStatus::NONE;
}
+
+ if (dataCacheHandle[0]->data[0] < 0)
+ {
+ ALOGW("ArmnnDriverImpl::prepareArmnnModel_1_3: Cannot cache the data, fd < 0");
+ NotifyCallbackAndCheck(cb, V1_3::ErrorStatus::NONE, preparedModel.release());
+ return V1_3::ErrorStatus::NONE;
+ }
+
int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
if (dataCacheFileAccessMode != O_RDWR)
{
@@ -435,6 +443,13 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3(
return V1_3::ErrorStatus::GENERAL_FAILURE;
}
+ if (dataCacheHandle[0]->data[0] < 0)
+ {
+ ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3(): Cannot read from the cache data, fd < 0");
+ cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
+ return V1_3::ErrorStatus::GENERAL_FAILURE;
+ }
+
int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
if (dataCacheFileAccessMode != O_RDWR)
{
@@ -456,16 +471,12 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3(
if (fstat(dataCacheHandle[0]->data[0], &statBuffer) == 0)
{
unsigned long bufferSize = statBuffer.st_size;
- if (bufferSize <= 0)
+ if (bufferSize != dataSize)
{
ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3: Invalid data to deserialize!");
cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
return V1_3::ErrorStatus::GENERAL_FAILURE;
}
- if (bufferSize > dataSize)
- {
- offset = bufferSize - dataSize;
- }
}
}
std::vector<uint8_t> dataCacheData(dataSize);
@@ -504,17 +515,20 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3(
if (cachedFd != -1 && fstat(cachedFd, &statBuffer) == 0)
{
long modelDataSize = statBuffer.st_size;
- if (modelDataSize > 0)
+ if (modelDataSize <= 0)
{
- std::vector<uint8_t> modelData(modelDataSize);
- pread(cachedFd, modelData.data(), modelData.size(), 0);
- hashValue ^= CacheDataHandlerInstance().Hash(modelData);
+ ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3(): Wrong cached model size!");
+ cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
+ return V1_3::ErrorStatus::NONE;
+ }
+ std::vector<uint8_t> modelData(modelDataSize);
+ pread(cachedFd, modelData.data(), modelData.size(), 0);
+ hashValue ^= CacheDataHandlerInstance().Hash(modelData);
- // For GpuAcc numberOfCachedFiles is 1
- if (backend == armnn::Compute::GpuAcc)
- {
- gpuAccCachedFd = cachedFd;
- }
+ // For GpuAcc numberOfCachedFiles is 1
+ if (backend == armnn::Compute::GpuAcc)
+ {
+ gpuAccCachedFd = cachedFd;
}
}
index += numberOfCacheFiles;
@@ -522,7 +536,7 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3(
}
}
- if (!CacheDataHandlerInstance().Validate(token, hashValue))
+ if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size()))
{
ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3: ValidateHash() failed!");
cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
@@ -530,7 +544,17 @@ Return<V1_3::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache_1_3(
}
// Deserialize the network..
- auto network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
+ armnn::INetworkPtr network = armnn::INetworkPtr(nullptr, [](armnn::INetwork*){});
+ try
+ {
+ network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
+ }
+ catch (std::exception&)
+ {
+ ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3: Exception caught from Deserializer!");
+ cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
+ return V1_3::ErrorStatus::GENERAL_FAILURE;
+ }
// Optimize the network
armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
diff --git a/CacheDataHandler.cpp b/CacheDataHandler.cpp
index 36881629..5f3a3076 100644
--- a/CacheDataHandler.cpp
+++ b/CacheDataHandler.cpp
@@ -18,19 +18,22 @@ CacheDataHandler& CacheDataHandlerInstance()
void CacheDataHandler::Register(const HidlToken token, const size_t hashValue, const size_t cacheSize)
{
- if (m_CacheDataMap.find(hashValue) != m_CacheDataMap.end())
+ if (m_CacheDataMap.find(hashValue) != m_CacheDataMap.end()
+ && m_CacheDataMap.at(hashValue).GetToken() == token
+ && m_CacheDataMap.at(hashValue).GetCacheSize() == cacheSize)
{
- ALOGV("CacheHandler::Register() Token has been already registered.");
+ ALOGV("CacheHandler::Register() Hash value has already registered.");
return;
}
CacheHandle cacheHandle(token, cacheSize);
m_CacheDataMap.insert({hashValue, cacheHandle});
}
-bool CacheDataHandler::Validate(const HidlToken token, const size_t hashValue) const
+bool CacheDataHandler::Validate(const HidlToken token, const size_t hashValue, const size_t cacheSize) const
{
return (m_CacheDataMap.find(hashValue) != m_CacheDataMap.end()
- && m_CacheDataMap.at(hashValue).GetToken() == token);
+ && m_CacheDataMap.at(hashValue).GetToken() == token
+ && m_CacheDataMap.at(hashValue).GetCacheSize() == cacheSize);
}
size_t CacheDataHandler::Hash(std::vector<uint8_t>& cacheData)
@@ -38,7 +41,7 @@ size_t CacheDataHandler::Hash(std::vector<uint8_t>& cacheData)
std::size_t hash = cacheData.size();
for (auto& i : cacheData)
{
- hash ^= std::hash<unsigned int>{}(i);
+ hash = ((hash << 5) - hash) + i;
}
return hash;
}
diff --git a/CacheDataHandler.hpp b/CacheDataHandler.hpp
index cea73d20..5b1b2951 100644
--- a/CacheDataHandler.hpp
+++ b/CacheDataHandler.hpp
@@ -48,7 +48,7 @@ public:
void Register(const HidlToken token, const size_t hashValue, const size_t cacheSize);
- bool Validate(const HidlToken token, const size_t hashValue) const;
+ bool Validate(const HidlToken token, const size_t hashValue, const size_t cacheSize) const;
size_t Hash(std::vector<uint8_t>& cacheData);