aboutsummaryrefslogtreecommitdiff
path: root/ArmnnPreparedModel_1_2.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'ArmnnPreparedModel_1_2.hpp')
-rw-r--r-- ArmnnPreparedModel_1_2.hpp 54
1 files changed, 39 insertions, 15 deletions
diff --git a/ArmnnPreparedModel_1_2.hpp b/ArmnnPreparedModel_1_2.hpp
index f609ef7e..e68614a0 100644
--- a/ArmnnPreparedModel_1_2.hpp
+++ b/ArmnnPreparedModel_1_2.hpp
@@ -19,18 +19,21 @@
namespace armnn_driver
{
-typedef std::function<void(::android::hardware::neuralnetworks::V1_0::ErrorStatus status,
- std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
- const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
- std::string callingFunction)> armnnExecuteCallback_1_2;
+using CallbackAsync_1_2 = std::function<
+ void(V1_0::ErrorStatus errorStatus,
+ std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
+ const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
+ std::string callingFunction)>;
-struct ArmnnCallback_1_2
+struct ExecutionContext_1_2
{
- armnnExecuteCallback_1_2 callback;
+ ::android::hardware::neuralnetworks::V1_2::MeasureTiming measureTimings =
+ ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO;
TimePoint driverStart;
- MeasureTiming measureTiming;
};
+using CallbackContext_1_2 = CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>;
+
template <typename HalVersion>
class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel
{
@@ -62,19 +65,38 @@ public:
configureExecutionBurst_cb cb) override;
/// execute the graph prepared from the request
- void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
- std::shared_ptr<armnn::InputTensors>& pInputTensors,
- std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
- ArmnnCallback_1_2 callbackDescriptor);
+ template<typename CallbackContext>
+ bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+ armnn::InputTensors& inputTensors,
+ armnn::OutputTensors& outputTensors,
+ CallbackContext callback);
/// Executes this model with dummy inputs (e.g. all zeroes).
/// \return false on failure, otherwise true
bool ExecuteWithDummyInputs();
private:
- Return <V1_0::ErrorStatus> Execute(const V1_0::Request& request,
- MeasureTiming measureTiming,
- armnnExecuteCallback_1_2 callback);
+ Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request,
+ MeasureTiming measureTiming,
+ CallbackAsync_1_2 callback);
+
+ Return<V1_0::ErrorStatus> PrepareMemoryForInputs(
+ armnn::InputTensors& inputs,
+ const V1_0::Request& request,
+ const std::vector<android::nn::RunTimePoolInfo>& memPools);
+
+ Return<V1_0::ErrorStatus> PrepareMemoryForOutputs(
+ armnn::OutputTensors& outputs,
+ std::vector<OutputShape> &outputShapes,
+ const V1_0::Request& request,
+ const std::vector<android::nn::RunTimePoolInfo>& memPools);
+
+ Return <V1_0::ErrorStatus> PrepareMemoryForIO(
+ armnn::InputTensors& inputs,
+ armnn::OutputTensors& outputs,
+ std::vector<android::nn::RunTimePoolInfo>& memPools,
+ const V1_0::Request& request,
+ CallbackAsync_1_2 callback);
template <typename TensorBindingCollection>
void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
@@ -84,7 +106,9 @@ private:
V1_2::Model m_Model;
// There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
// It is specific to this class, so it is declared as static here
- static RequestThread<ArmnnPreparedModel_1_2, HalVersion, ArmnnCallback_1_2> m_RequestThread;
+ static RequestThread<ArmnnPreparedModel_1_2,
+ HalVersion,
+ CallbackContext_1_2> m_RequestThread;
uint32_t m_RequestCount;
const std::string& m_RequestInputsAndOutputsDumpDir;
const bool m_GpuProfilingEnabled;