diff options
author | Grant Watson <grant.watson@arm.com> | 2022-11-16 15:32:39 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2022-12-15 16:41:27 +0000 |
commit | 64285a1f25e2c7b85ed1f00b7947403e92baea00 (patch) | |
tree | 6d29c54f6497741449339e808508c854ba6a2267 /reference_model/src/model_runner_impl.h | |
parent | b45db9a696f5df7b233f374248f329c16ee7ae64 (diff) | |
download | reference_model-64285a1f25e2c7b85ed1f00b7947403e92baea00.tar.gz |
Extend reference model API with eager operator execution entrypoints
- Adds a script to generate operators.h and operators.cc
- Adds jinja2 templates for generating operators.h and operators.cc
- Adds unit tests for a subset of the operators generated
- Includes the TOSA specification as a submodule
- Adds supporting C++ and header files
Signed-off-by: Grant Watson <grant.watson@arm.com>
Change-Id: I5b60db4c56113110d8e75fe1152525d258233f9c
Diffstat (limited to 'reference_model/src/model_runner_impl.h')
-rw-r--r-- | reference_model/src/model_runner_impl.h | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/reference_model/src/model_runner_impl.h b/reference_model/src/model_runner_impl.h index f26c484..b43370c 100644 --- a/reference_model/src/model_runner_impl.h +++ b/reference_model/src/model_runner_impl.h @@ -20,6 +20,7 @@ #include "graph_status.h" #include "version.h" +#include "array_proxy.h" #include "ops/op_factory.h" #include "subgraph_traverser.h" #include "tosa_serialization_handler.h" @@ -42,14 +43,17 @@ public: void setFuncConfig(func_config_t& func_config); void setFuncDebug(func_debug_t& func_debug); + GraphStatus initialize(TosaSerializationBasicBlock& bb); GraphStatus initialize(TosaSerializationHandler& serialization_handler); GraphStatus run(); template <typename T> - int setInput(std::string input_name, std::vector<T>& vals); + int setInput(std::string input_name, ArrayProxy<T> vals); + int setInput(std::string input_name, uint8_t* raw_ptr, size_t size); template <typename T> std::vector<T> getOutput(std::string output_name); + int getOutput(std::string output_name, uint8_t* ptr, size_t size); private: SubgraphTraverser* _main_gt = nullptr; @@ -57,6 +61,7 @@ private: // Used to determine if all input tensors have been set correctly. uint32_t n_input_tensors = 0; + GraphStatus initialize(TosaSerializationBasicBlock* bb, TosaSerializationHandler* serialization_handler); void validateTosaVersion(TosaSerializationHandler& serialization_handler); void checkGraphStatus(SubgraphTraverser& main_gt); }; |