aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/model_runner_impl.h
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/model_runner_impl.h')
-rw-r--r--reference_model/src/model_runner_impl.h7
1 files changed, 6 insertions, 1 deletions
diff --git a/reference_model/src/model_runner_impl.h b/reference_model/src/model_runner_impl.h
index f26c484..b43370c 100644
--- a/reference_model/src/model_runner_impl.h
+++ b/reference_model/src/model_runner_impl.h
@@ -20,6 +20,7 @@
#include "graph_status.h"
#include "version.h"
+#include "array_proxy.h"
#include "ops/op_factory.h"
#include "subgraph_traverser.h"
#include "tosa_serialization_handler.h"
@@ -42,14 +43,17 @@ public:
void setFuncConfig(func_config_t& func_config);
void setFuncDebug(func_debug_t& func_debug);
+ GraphStatus initialize(TosaSerializationBasicBlock& bb);
GraphStatus initialize(TosaSerializationHandler& serialization_handler);
GraphStatus run();
template <typename T>
- int setInput(std::string input_name, std::vector<T>& vals);
+ int setInput(std::string input_name, ArrayProxy<T> vals);
+ int setInput(std::string input_name, uint8_t* raw_ptr, size_t size);
template <typename T>
std::vector<T> getOutput(std::string output_name);
+ int getOutput(std::string output_name, uint8_t* ptr, size_t size);
private:
SubgraphTraverser* _main_gt = nullptr;
@@ -57,6 +61,7 @@ private:
// Used to determine if all input tensors have been set correctly.
uint32_t n_input_tensors = 0;
+ GraphStatus initialize(TosaSerializationBasicBlock* bb, TosaSerializationHandler* serialization_handler);
void validateTosaVersion(TosaSerializationHandler& serialization_handler);
void checkGraphStatus(SubgraphTraverser& main_gt);
};