aboutsummaryrefslogtreecommitdiff
path: root/reference_model/include/model_runner.h
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/include/model_runner.h')
-rw-r--r--reference_model/include/model_runner.h14
1 files changed, 14 insertions, 0 deletions
diff --git a/reference_model/include/model_runner.h b/reference_model/include/model_runner.h
index 4335794..86d0056 100644
--- a/reference_model/include/model_runner.h
+++ b/reference_model/include/model_runner.h
@@ -71,6 +71,13 @@ public:
int setInput(std::string input_name, std::vector<T>& vals);
/*
+ * Set the input tensors for the model through a raw byte buffer.
+ * The input_name much match the input tensor name in the model.
+ * NOTE: setInput() must be called for each input tensor before run() is called.
+ */
+ int setInput(std::string input_name, uint8_t* raw_ptr, size_t size);
+
+ /*
* Retrieve the output tensors from the graph after running.
* The output_name much match the output tensor name in the model.
* NOTE: run() must be called before outputs are retrieved.
@@ -78,6 +85,13 @@ public:
template <typename T>
std::vector<T> getOutput(std::string output_name);
+ /*
+ * Retrieve the output tensors from the graph after running in a raw byte buffer.
+ * The output_name much match the output tensor name in the model.
+ * NOTE: run() must be called before outputs are retrieved.
+ */
+ int getOutput(std::string output_name, uint8_t* raw_ptr, size_t size);
+
private:
std::unique_ptr<ModelRunnerImpl> model_runner_impl;
};