path: root/ArmnnPreparedModel_1_2.cpp
author     Mike Kelly <mike.kelly@arm.com>  2019-07-08 17:37:35 +0100
committer  Mike Kelly <mike.kelly@arm.com>  2019-07-08 17:38:19 +0100
commit     44381518586476ce7aef78b00bc6a905ddf5730a (patch)
tree       e1b4ea56a520f6cb6aa518550a105157e0670c33 /ArmnnPreparedModel_1_2.cpp
parent     a6bc52f6b9eeddcdebf4e660b21a4409a592ac4e (diff)
download   android-nn-driver-44381518586476ce7aef78b00bc6a905ddf5730a.tar.gz
IVGCVSW-3351 Run VTS tests
* Added ArmnnBurstExecutorWithCache to fix test failures.
* Added support for MeasureTiming to fix test failures.

Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I12b7c6228354bac1f1a9b61ee78066219c0923ad
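For reference, the MeasureTiming support added in the diff below reduces to a small steady_clock pattern: record time points around the measured work and report the elapsed microseconds. The following is a minimal, self-contained sketch of that pattern using only the C++ standard library; the helper names mirror the ones introduced in this change, while main() and the sleep are illustrative stand-ins for the measured workload, not driver code.

    // Sketch of the timing pattern added in this commit: capture steady_clock
    // time points around the measured work and report the elapsed time in
    // microseconds, as the driver does for timeOnDevice / timeInDriver.
    #include <chrono>
    #include <cstdio>
    #include <thread>

    using TimePoint = std::chrono::steady_clock::time_point;

    static TimePoint Now()
    {
        return std::chrono::steady_clock::now();
    }

    static unsigned long MicrosecondsDuration(TimePoint endPoint, TimePoint startPoint)
    {
        return static_cast<unsigned long>(
            std::chrono::duration_cast<std::chrono::microseconds>(endPoint - startPoint).count());
    }

    int main()
    {
        TimePoint driverStart = Now();
        // Stand-in for the enqueued workload being timed.
        std::this_thread::sleep_for(std::chrono::milliseconds(5));
        TimePoint driverEnd = Now();

        std::printf("timeInDriver = %lu us\n", MicrosecondsDuration(driverEnd, driverStart));
        return 0;
    }

In the driver itself (see the hunks below), the same measurement is taken twice per request: once around EnqueueWorkload for timeOnDevice and once around the whole executeSynchronously call for timeInDriver, and the result is returned through the callback only when MeasureTiming::YES is requested.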
Diffstat (limited to 'ArmnnPreparedModel_1_2.cpp')
-rw-r--r--  ArmnnPreparedModel_1_2.cpp  153
1 file changed, 143 insertions(+), 10 deletions(-)
diff --git a/ArmnnPreparedModel_1_2.cpp b/ArmnnPreparedModel_1_2.cpp
index f03d69d9..74da4734 100644
--- a/ArmnnPreparedModel_1_2.cpp
+++ b/ArmnnPreparedModel_1_2.cpp
@@ -20,11 +20,22 @@
using namespace android;
using namespace android::hardware;
-static const Timing g_NoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
-
namespace {
+static const Timing g_NoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
using namespace armnn_driver;
+using TimePoint = std::chrono::steady_clock::time_point;
+
+TimePoint Now()
+{
+ return std::chrono::steady_clock::now();
+}
+
+unsigned long MicrosecondsDuration(TimePoint endPoint, TimePoint startPoint)
+{
+ return static_cast<unsigned long>(std::chrono::duration_cast<std::chrono::microseconds>(
+ endPoint - startPoint).count());
+}
void NotifyCallbackAndCheck(const ::android::sp<V1_0::IExecutionCallback>& callback, ErrorStatus errorStatus,
std::string callingFunction)
@@ -167,8 +178,8 @@ Return <ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::execute_1_2(const Reque
template<typename HalVersion>
Return<void> ArmnnPreparedModel_1_2<HalVersion>::executeSynchronously(const Request& request,
- MeasureTiming,
- V1_2::IPreparedModel::executeSynchronously_cb cb)
+ MeasureTiming measureTiming,
+ executeSynchronously_cb cb)
{
ALOGV("ArmnnPreparedModel_1_2::executeSynchronously(): %s", GetModelSummary(m_Model).c_str());
m_RequestCount++;
@@ -179,8 +190,16 @@ Return<void> ArmnnPreparedModel_1_2<HalVersion>::executeSynchronously(const Requ
return Void();
}
+ TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
+
+ if (measureTiming == MeasureTiming::YES)
+ {
+ driverStart = Now();
+ }
+
if (!android::nn::validateRequest(request, m_Model))
{
+ ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid request model");
cb(ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming);
return Void();
}
@@ -247,12 +266,21 @@ Return<void> ArmnnPreparedModel_1_2<HalVersion>::executeSynchronously(const Requ
ALOGV("ArmnnPreparedModel_1_2::executeSynchronously() before Execution");
DumpTensorsIfRequired("Input", *pInputTensors);
-
// run it
try
{
+ if (measureTiming == MeasureTiming::YES)
+ {
+ deviceStart = Now();
+ }
+
armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, *pInputTensors, *pOutputTensors);
+ if (measureTiming == MeasureTiming::YES)
+ {
+ deviceEnd = Now();
+ }
+
if (status != armnn::Status::Success)
{
ALOGW("EnqueueWorkload failed");
@@ -277,11 +305,111 @@ Return<void> ArmnnPreparedModel_1_2<HalVersion>::executeSynchronously(const Requ
pool.update();
}
ALOGV("ArmnnPreparedModel_1_2::executeSynchronously() after Execution");
- cb(ErrorStatus::NONE, {}, g_NoTiming);
+
+ if (measureTiming == MeasureTiming::YES)
+ {
+ driverEnd = Now();
+ Timing timing;
+ timing.timeOnDevice = MicrosecondsDuration(deviceEnd, deviceStart);
+ timing.timeInDriver = MicrosecondsDuration(driverEnd, driverStart);
+ ALOGV("ArmnnPreparedModel_1_2::executeSynchronously timing Device = %lu Driver = %lu", timing.timeOnDevice,
+ timing.timeInDriver);
+ cb(ErrorStatus::NONE, {}, timing);
+ }
+ else
+ {
+ cb(ErrorStatus::NONE, {}, g_NoTiming);
+ }
return Void();
}
template<typename HalVersion>
+class ArmnnBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache
+{
+public:
+ ArmnnBurstExecutorWithCache(ArmnnPreparedModel_1_2<HalVersion>* preparedModel)
+ : m_PreparedModel(preparedModel)
+ {}
+
+ bool isCacheEntryPresent(int slot) const override
+ {
+ const auto it = m_MemoryCache.find(slot);
+ return (it != m_MemoryCache.end()) && it->second.valid();
+ }
+
+ void addCacheEntry(const hidl_memory& memory, int slot) override
+ {
+ m_MemoryCache[slot] = memory;
+ }
+
+ void removeCacheEntry(int slot) override
+ {
+ m_MemoryCache.erase(slot);
+ }
+
+ std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
+ const Request& request, const std::vector<int>& slots,
+ MeasureTiming measure) override {
+ ALOGV("ArmnnPreparedModel_1_2::BurstExecutorWithCache::execute");
+ TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
+
+ if (measure == MeasureTiming::YES)
+ {
+ driverStart = Now();
+ }
+ hidl_vec<hidl_memory> pools(slots.size());
+
+ for (int slot : slots)
+ {
+ if (!isCacheEntryPresent(slot))
+ {
+ ALOGE("ArmnnPreparedModel_1_2::BurstExecutorWithCache::no cache entry present");
+ return std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing>(ErrorStatus::INVALID_ARGUMENT,
+ {},
+ g_NoTiming);
+ }
+ pools[slot] = m_MemoryCache[slot];
+ }
+
+ Request fullRequest = request;
+ fullRequest.pools = std::move(pools);
+
+ // Setup callback
+ ErrorStatus returnedStatus = ErrorStatus::GENERAL_FAILURE;
+ hidl_vec<OutputShape> returnedOutputShapes;
+ Timing returnedTiming;
+
+ auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](ErrorStatus status,
+ const hidl_vec<OutputShape>& outputShapes,
+ const Timing& timing)
+ {
+ returnedStatus = status;
+ returnedOutputShapes = outputShapes;
+ returnedTiming = timing;
+ };
+
+ // Execute
+ ALOGV("ArmnnPreparedModel_1_2::BurstExecutorWithCache executing");
+ Return<void> ret = m_PreparedModel->executeSynchronously(fullRequest, measure, cb);
+
+ if (!ret.isOk() || returnedStatus != ErrorStatus::NONE)
+ {
+ ALOGE("ArmnnPreparedModel_1_2::BurstExecutorWithCache::error executing");
+ return std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing>(returnedStatus, {}, returnedTiming);
+ }
+
+ return std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing>(returnedStatus,
+ std::move(returnedOutputShapes),
+ returnedTiming);
+ }
+
+private:
+ Model m_Model;
+ ArmnnPreparedModel_1_2<HalVersion>* m_PreparedModel;
+ std::map<int, hidl_memory> m_MemoryCache;
+};
+
+template<typename HalVersion>
Return<void> ArmnnPreparedModel_1_2<HalVersion>::configureExecutionBurst(
const sp<V1_2::IBurstCallback>& callback,
const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
@@ -289,12 +417,17 @@ Return<void> ArmnnPreparedModel_1_2<HalVersion>::configureExecutionBurst(
V1_2::IPreparedModel::configureExecutionBurst_cb cb)
{
ALOGV("ArmnnPreparedModel_1_2::configureExecutionBurst");
- const sp<V1_2::IBurstContext> burst =
- ExecutionBurstServer::create(callback, requestChannel, resultChannel, this);
+ const std::shared_ptr<ArmnnBurstExecutorWithCache<HalVersion>> executorWithCache =
+ std::make_shared<ArmnnBurstExecutorWithCache<HalVersion>>(this);
+ const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(
+ callback, requestChannel, resultChannel, executorWithCache);
- if (burst == nullptr) {
+ if (burst == nullptr)
+ {
cb(ErrorStatus::GENERAL_FAILURE, {});
- } else {
+ }
+ else
+ {
cb(ErrorStatus::NONE, burst);
}
return Void();