aboutsummaryrefslogtreecommitdiff
path: root/shim/sl/canonical/ArmnnDriverImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'shim/sl/canonical/ArmnnDriverImpl.cpp')
-rw-r--r--shim/sl/canonical/ArmnnDriverImpl.cpp91
1 files changed, 50 insertions, 41 deletions
diff --git a/shim/sl/canonical/ArmnnDriverImpl.cpp b/shim/sl/canonical/ArmnnDriverImpl.cpp
index 8706c382b0..0c98a16138 100644
--- a/shim/sl/canonical/ArmnnDriverImpl.cpp
+++ b/shim/sl/canonical/ArmnnDriverImpl.cpp
@@ -5,7 +5,6 @@
#include "ArmnnDriverImpl.hpp"
#include "ArmnnPreparedModel.hpp"
-#include "CacheDataHandler.hpp"
#include "ModelToINetworkTransformer.hpp"
#include "SystemPropertiesUtils.hpp"
@@ -62,6 +61,16 @@ Capabilities GenerateCapabilities()
/* whilePerformance */ defaultPerfInfo };
}
+size_t Hash(std::vector<uint8_t>& cacheData)
+{
+ std::size_t hash = cacheData.size();
+ for (auto& i : cacheData)
+ {
+ hash = ((hash << 5) - hash) + i;
+ }
+ return hash;
+}
+
} // anonymous namespace
using namespace android::nn;
@@ -87,33 +96,6 @@ bool ArmnnDriverImpl::ValidateSharedHandle(const SharedHandle& sharedHandle)
return valid;
}
-bool ArmnnDriverImpl::ValidateDataCacheHandle(const std::vector<SharedHandle>& dataCacheHandle, const size_t dataSize)
-{
- bool valid = true;
- // DataCacheHandle size should always be 1 for ArmNN model
- if (dataCacheHandle.size() != 1)
- {
- return !valid;
- }
-
- if (dataSize == 0)
- {
- return !valid;
- }
-
- struct stat statBuffer;
- if (fstat(*dataCacheHandle[0], &statBuffer) == 0)
- {
- unsigned long bufferSize = statBuffer.st_size;
- if (bufferSize != dataSize)
- {
- return !valid;
- }
- }
-
- return ValidateSharedHandle(dataCacheHandle[0]);
-}
-
GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModel(
const armnn::IRuntimePtr& runtime,
const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
@@ -274,8 +256,7 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModel(
size_t hashValue = 0;
if (dataCacheHandle.size() == 1 )
{
- write(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size());
- hashValue = CacheDataHandlerInstance().Hash(dataCacheData);
+ hashValue = Hash(dataCacheData);
}
// Cache the model data
@@ -296,16 +277,20 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModel(
{
std::vector<uint8_t> modelData(modelDataSize);
pread(*modelCacheHandle[i], modelData.data(), modelData.size(), 0);
- hashValue ^= CacheDataHandlerInstance().Hash(modelData);
+ hashValue ^= Hash(modelData);
}
}
}
}
}
}
- if (hashValue != 0)
+ if (dataCacheHandle.size() == 1 && hashValue != 0)
{
- CacheDataHandlerInstance().Register(token, hashValue, dataCacheData.size());
+ std::vector<uint8_t> theHashValue(sizeof(hashValue));
+ ::memcpy(theHashValue.data(), &hashValue, sizeof(hashValue));
+
+ write(*dataCacheHandle[0], theHashValue.data(), theHashValue.size());
+ pwrite(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), theHashValue.size());
}
bool executeWithDummyInputs = (std::find(options.GetBackends().begin(),
@@ -374,13 +359,30 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModelFromCache(
}
// Validate dataCacheHandle
- auto dataSize = CacheDataHandlerInstance().GetCacheSize(token);
- if (!ValidateDataCacheHandle(dataCacheHandle, dataSize))
+ if (dataCacheHandle.size() != 1)
{
return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
<< "ArmnnDriverImpl::prepareModelFromCache(): Not valid data cache handle!";
}
+ if (!ValidateSharedHandle(dataCacheHandle[0]))
+ {
+ return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
+ << "ArmnnDriverImpl::prepareModelFromCache(): Not valid data cache handle!";
+ }
+
+ size_t cachedDataSize = 0;
+ struct stat dataStatBuffer;
+ if (fstat(*dataCacheHandle[0], &dataStatBuffer) == 0)
+ {
+ cachedDataSize = dataStatBuffer.st_size;
+ }
+ if (cachedDataSize == 0)
+ {
+ return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
+ << "ArmnnDriverImpl::prepareModelFromCache(): Not valid cached data!";
+ }
+
// Check if model files cached they match the expected value
unsigned int numberOfCachedModelFiles = 0;
for (auto& backend : options.GetBackends())
@@ -393,10 +395,14 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModelFromCache(
<< "ArmnnDriverImpl::prepareModelFromCache(): Model cache handle size does not match.";
}
+ // Read the hashValue
+ std::vector<uint8_t> hashValue(sizeof(size_t));
+ pread(*dataCacheHandle[0], hashValue.data(), hashValue.size(), 0);
+
// Read the model
- std::vector<uint8_t> dataCacheData(dataSize);
- pread(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), 0);
- auto hashValue = CacheDataHandlerInstance().Hash(dataCacheData);
+ std::vector<uint8_t> dataCacheData(cachedDataSize - hashValue.size());
+ pread(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), hashValue.size());
+ auto calculatedHashValue = Hash(dataCacheData);
int gpuAccCachedFd = -1;
if (modelCacheHandle.size() > 0)
@@ -423,7 +429,7 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModelFromCache(
{
std::vector<uint8_t> modelData(modelDataSize);
pread(cachedFd, modelData.data(), modelData.size(), 0);
- hashValue ^= CacheDataHandlerInstance().Hash(modelData);
+ calculatedHashValue ^= Hash(modelData);
if (backend == armnn::Compute::GpuAcc)
{
@@ -436,7 +442,9 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModelFromCache(
}
}
- if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size()))
+ std::vector<uint8_t> calculatedHashData(sizeof(calculatedHashValue));
+ ::memcpy(calculatedHashData.data(), &calculatedHashValue, sizeof(calculatedHashValue));
+ if (hashValue != calculatedHashData)
{
return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
<< "ArmnnDriverImpl::prepareModelFromCache(): ValidateHash() failed!";
@@ -529,12 +537,13 @@ GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModelFromCache(
return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << message.str();
}
- return std::make_shared<const ArmnnPreparedModel>(netId,
+ auto preparedModel = std::make_shared<const ArmnnPreparedModel>(netId,
runtime.get(),
options.GetRequestInputsAndOutputsDumpDir(),
options.IsGpuProfilingEnabled(),
Priority::MEDIUM,
true);
+ return std::move(preparedModel);
}
const Capabilities& ArmnnDriverImpl::GetCapabilities(const armnn::IRuntimePtr& runtime)