aboutsummaryrefslogtreecommitdiff
path: root/ArmnnPreparedModel.hpp
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2018-09-03 13:50:50 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-09-18 12:40:38 +0100
commite48bdff741568236d3c0747ad3d18a8eba5b36dd (patch)
tree77aabce6f75d86d3f2f3924f342292ae5a7267e7 /ArmnnPreparedModel.hpp
parenta15dc11fd7bf3ad49e752ec75157b731287fe46d (diff)
downloadandroid-nn-driver-e48bdff741568236d3c0747ad3d18a8eba5b36dd.tar.gz
IVGCVSW-1806 Refactored Android-NN-Driver, added common "getCapabilities",
"getSupportedOperations" and "prepareModel" implementations * Added common base ArmnnDriverImpl class * Added common template implementation of the driver's "getCapabilities", "getSupportedOperations" and "prepareModel" methods * Refactored ArmnnPreparedModel and RequestThread to support HAL v1.1 models * Moved "getStatus" to the common base class, as it is shared by both HAL implementations * Refactored the code where necessary Change-Id: I747334730026d63b4002662523fb93608f67c899
Diffstat (limited to 'ArmnnPreparedModel.hpp')
-rw-r--r--ArmnnPreparedModel.hpp21
1 files changed, 12 insertions, 9 deletions
diff --git a/ArmnnPreparedModel.hpp b/ArmnnPreparedModel.hpp
index a700e54d..86c6f5cf 100644
--- a/ArmnnPreparedModel.hpp
+++ b/ArmnnPreparedModel.hpp
@@ -8,6 +8,7 @@
#include "RequestThread.hpp"
#include "ArmnnDriver.hpp"
+#include "ArmnnDriverImpl.hpp"
#include <NeuralNetworks.h>
#include <armnn/ArmNN.hpp>
@@ -18,12 +19,15 @@
namespace armnn_driver
{
+template <typename HalVersion>
class ArmnnPreparedModel : public IPreparedModel
{
public:
+ using HalModel = typename HalVersion::Model;
+
ArmnnPreparedModel(armnn::NetworkId networkId,
armnn::IRuntime* runtime,
- const ::android::hardware::neuralnetworks::V1_0::Model& model,
+ const HalModel& model,
const std::string& requestInputsAndOutputsDumpDir,
const bool gpuProfilingEnabled);
@@ -42,19 +46,18 @@ public:
void ExecuteWithDummyInputs();
private:
-
template <typename TensorBindingCollection>
void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
- armnn::NetworkId m_NetworkId;
- armnn::IRuntime* m_Runtime;
- ::android::hardware::neuralnetworks::V1_0::Model m_Model;
+ armnn::NetworkId m_NetworkId;
+ armnn::IRuntime* m_Runtime;
+ HalModel m_Model;
// There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
// It is specific to this class, so it is declared as static here
- static RequestThread m_RequestThread;
- uint32_t m_RequestCount;
- const std::string& m_RequestInputsAndOutputsDumpDir;
- const bool m_GpuProfilingEnabled;
+ static RequestThread<HalVersion> m_RequestThread;
+ uint32_t m_RequestCount;
+ const std::string& m_RequestInputsAndOutputsDumpDir;
+ const bool m_GpuProfilingEnabled;
};
}