aboutsummaryrefslogtreecommitdiff
path: root/ArmnnPreparedModel.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'ArmnnPreparedModel.hpp')
-rw-r--r--ArmnnPreparedModel.hpp21
1 files changed, 12 insertions, 9 deletions
diff --git a/ArmnnPreparedModel.hpp b/ArmnnPreparedModel.hpp
index a700e54d..86c6f5cf 100644
--- a/ArmnnPreparedModel.hpp
+++ b/ArmnnPreparedModel.hpp
@@ -8,6 +8,7 @@
#include "RequestThread.hpp"
#include "ArmnnDriver.hpp"
+#include "ArmnnDriverImpl.hpp"
#include <NeuralNetworks.h>
#include <armnn/ArmNN.hpp>
@@ -18,12 +19,15 @@
namespace armnn_driver
{
+template <typename HalVersion>
class ArmnnPreparedModel : public IPreparedModel
{
public:
+ using HalModel = typename HalVersion::Model;
+
ArmnnPreparedModel(armnn::NetworkId networkId,
armnn::IRuntime* runtime,
- const ::android::hardware::neuralnetworks::V1_0::Model& model,
+ const HalModel& model,
const std::string& requestInputsAndOutputsDumpDir,
const bool gpuProfilingEnabled);
@@ -42,19 +46,18 @@ public:
void ExecuteWithDummyInputs();
private:
-
template <typename TensorBindingCollection>
void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
- armnn::NetworkId m_NetworkId;
- armnn::IRuntime* m_Runtime;
- ::android::hardware::neuralnetworks::V1_0::Model m_Model;
+ armnn::NetworkId m_NetworkId;
+ armnn::IRuntime* m_Runtime;
+ HalModel m_Model;
// There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
// It is specific to this class, so it is declared as static here
- static RequestThread m_RequestThread;
- uint32_t m_RequestCount;
- const std::string& m_RequestInputsAndOutputsDumpDir;
- const bool m_GpuProfilingEnabled;
+ static RequestThread<HalVersion> m_RequestThread;
+ uint32_t m_RequestCount;
+ const std::string& m_RequestInputsAndOutputsDumpDir;
+ const bool m_GpuProfilingEnabled;
};
}