diff options
Diffstat (limited to 'reference_model/src/model_runner_impl.h')
-rw-r--r-- | reference_model/src/model_runner_impl.h | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/reference_model/src/model_runner_impl.h b/reference_model/src/model_runner_impl.h index f26c484..b43370c 100644 --- a/reference_model/src/model_runner_impl.h +++ b/reference_model/src/model_runner_impl.h @@ -20,6 +20,7 @@ #include "graph_status.h" #include "version.h" +#include "array_proxy.h" #include "ops/op_factory.h" #include "subgraph_traverser.h" #include "tosa_serialization_handler.h" @@ -42,14 +43,17 @@ public: void setFuncConfig(func_config_t& func_config); void setFuncDebug(func_debug_t& func_debug); + GraphStatus initialize(TosaSerializationBasicBlock& bb); GraphStatus initialize(TosaSerializationHandler& serialization_handler); GraphStatus run(); template <typename T> - int setInput(std::string input_name, std::vector<T>& vals); + int setInput(std::string input_name, ArrayProxy<T> vals); + int setInput(std::string input_name, uint8_t* raw_ptr, size_t size); template <typename T> std::vector<T> getOutput(std::string output_name); + int getOutput(std::string output_name, uint8_t* ptr, size_t size); private: SubgraphTraverser* _main_gt = nullptr; @@ -57,6 +61,7 @@ private: // Used to determine if all input tensors have been set correctly. uint32_t n_input_tensors = 0; + GraphStatus initialize(TosaSerializationBasicBlock* bb, TosaSerializationHandler* serialization_handler); void validateTosaVersion(TosaSerializationHandler& serialization_handler); void checkGraphStatus(SubgraphTraverser& main_gt); }; |