aboutsummaryrefslogtreecommitdiff
path: root/test/DriverTestHelpers.hpp
diff options
context:
space:
mode:
authorsurmeh01 <surabhi.mehta@arm.com>2018-05-17 14:11:25 +0100
committertelsoa01 <telmo.soares@arm.com>2018-05-23 16:23:49 +0100
commit49b9e100bfbb3b8da01472a0ff48b2bd92944e01 (patch)
tree1a998fa12f665ff0a15b299d8bae5590e0aed884 /test/DriverTestHelpers.hpp
parent28adb40e1bb1d3f3a06a7f333f7f2a4f42d3ed4b (diff)
downloadandroid-nn-driver-49b9e100bfbb3b8da01472a0ff48b2bd92944e01.tar.gz
Release 18.05
Diffstat (limited to 'test/DriverTestHelpers.hpp')
-rw-r--r--test/DriverTestHelpers.hpp135
1 files changed, 135 insertions, 0 deletions
diff --git a/test/DriverTestHelpers.hpp b/test/DriverTestHelpers.hpp
new file mode 100644
index 00000000..e90f7ecf
--- /dev/null
+++ b/test/DriverTestHelpers.hpp
@@ -0,0 +1,135 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#pragma once
+
+#ifndef LOG_TAG
+#define LOG_TAG "ArmnnDriverTests"
+#endif // LOG_TAG
+
+#include "../ArmnnDriver.hpp"
+#include <iosfwd>
+
+namespace android
+{
+namespace hardware
+{
+namespace neuralnetworks
+{
+namespace V1_0
+{
+
+std::ostream& operator<<(std::ostream& os, ErrorStatus stat);
+
+} // namespace android::hardware::neuralnetworks::V1_0
+} // namespace android::hardware::neuralnetworks
+} // namespace android::hardware
+} // namespace android
+
+namespace driverTestHelpers
+{
+
+std::ostream& operator<<(std::ostream& os, android::hardware::neuralnetworks::V1_0::ErrorStatus stat);
+
+struct ExecutionCallback : public IExecutionCallback
+{
+ ExecutionCallback() : mNotified(false) {}
+ Return<void> notify(ErrorStatus status) override;
+ /// wait until the callback has notified us that it is done
+ Return<void> wait();
+
+private:
+ // use a mutex and a condition variable to wait for asynchronous callbacks
+ std::mutex mMutex;
+ std::condition_variable mCondition;
+ // and a flag, in case we are notified before the wait call
+ bool mNotified;
+};
+
+class PreparedModelCallback : public IPreparedModelCallback
+{
+public:
+ PreparedModelCallback()
+ : m_ErrorStatus(ErrorStatus::NONE)
+ , m_PreparedModel()
+ { }
+ ~PreparedModelCallback() override { }
+
+ Return<void> notify(ErrorStatus status,
+ const android::sp<IPreparedModel>& preparedModel) override;
+ ErrorStatus GetErrorStatus() { return m_ErrorStatus; }
+ android::sp<IPreparedModel> GetPreparedModel() { return m_PreparedModel; }
+
+private:
+ ErrorStatus m_ErrorStatus;
+ android::sp<IPreparedModel> m_PreparedModel;
+};
+
+hidl_memory allocateSharedMemory(int64_t size);
+
+android::sp<IMemory> AddPoolAndGetData(uint32_t size, Request& request);
+
+void AddPoolAndSetData(uint32_t size, Request& request, const float* data);
+
+void AddOperand(Model& model, const Operand& op);
+
+void AddIntOperand(Model& model, int32_t value);
+
+template<typename T>
+OperandType TypeToOperandType();
+
+template<>
+OperandType TypeToOperandType<float>();
+
+template<>
+OperandType TypeToOperandType<int32_t>();
+
+template<typename T>
+void AddTensorOperand(Model& model, hidl_vec<uint32_t> dimensions, T* values)
+{
+ uint32_t totalElements = 1;
+ for (uint32_t dim : dimensions)
+ {
+ totalElements *= dim;
+ }
+
+ DataLocation location = {};
+ location.offset = model.operandValues.size();
+ location.length = totalElements * sizeof(T);
+
+ Operand op = {};
+ op.type = TypeToOperandType<T>();
+ op.dimensions = dimensions;
+ op.lifetime = OperandLifeTime::CONSTANT_COPY;
+ op.location = location;
+
+ model.operandValues.resize(model.operandValues.size() + location.length);
+ for (uint32_t i = 0; i < totalElements; i++)
+ {
+ *(reinterpret_cast<T*>(&model.operandValues[location.offset]) + i) = values[i];
+ }
+
+ AddOperand(model, op);
+}
+
+void AddInputOperand(Model& model, hidl_vec<uint32_t> dimensions);
+
+void AddOutputOperand(Model& model, hidl_vec<uint32_t> dimensions);
+
+android::sp<IPreparedModel> PrepareModel(const Model& model,
+ armnn_driver::ArmnnDriver& driver);
+
+android::sp<IPreparedModel> PrepareModelWithStatus(const Model& model,
+ armnn_driver::ArmnnDriver& driver,
+ ErrorStatus & prepareStatus,
+ ErrorStatus expectedStatus=ErrorStatus::NONE);
+
+ErrorStatus Execute(android::sp<IPreparedModel> preparedModel,
+ const Request& request,
+ ErrorStatus expectedStatus=ErrorStatus::NONE);
+
+android::sp<ExecutionCallback> ExecuteNoWait(android::sp<IPreparedModel> preparedModel,
+ const Request& request);
+
+} // namespace driverTestHelpers