aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/tensor.h
diff options
context:
space:
mode:
authorGrant Watson <grant.watson@arm.com>2022-11-16 15:32:39 +0000
committerEric Kunze <eric.kunze@arm.com>2022-12-15 16:41:27 +0000
commit64285a1f25e2c7b85ed1f00b7947403e92baea00 (patch)
tree6d29c54f6497741449339e808508c854ba6a2267 /reference_model/src/tensor.h
parentb45db9a696f5df7b233f374248f329c16ee7ae64 (diff)
downloadreference_model-64285a1f25e2c7b85ed1f00b7947403e92baea00.tar.gz
Extend reference model API with eager operator execution entrypoints
- Adds a script to generate operators.h and operators.cc - Adds jinja2 templates for generating operators.h and operators.cc - Adds unit tests for a subset of the operators generated - Includes the TOSA specification as a submodule - Adds supporting C++ and header files Signed-off-by: Grant Watson <grant.watson@arm.com> Change-Id: I5b60db4c56113110d8e75fe1152525d258233f9c
Diffstat (limited to 'reference_model/src/tensor.h')
-rw-r--r--reference_model/src/tensor.h23
1 files changed, 12 insertions, 11 deletions
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index a3ce4bb..08e865a 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -16,6 +16,7 @@
#ifndef TOSA_REFERENCE_TENSOR_H
#define TOSA_REFERENCE_TENSOR_H
+#include "array_proxy.h"
#include "model_common.h"
#include "ops/template_types.h"
#include "tosa_generated.h"
@@ -228,17 +229,17 @@ public:
virtual int writeToNpyFile(const char* filename) const;
virtual int copyValueFrom(Tensor* tensor) = 0;
- virtual int readfromVector(const std::vector<float>& vals);
- virtual int readfromVector(const std::vector<half_float::half>& vals);
- virtual int readfromVector(const std::vector<int32_t>& vals);
- virtual int readfromVector(const std::vector<int64_t>& vals);
- virtual int readfromVector(const std::vector<unsigned char>& vals);
-
- virtual int writeToVector(std::vector<float>& vals);
- virtual int writeToVector(std::vector<half_float::half>& vals);
- virtual int writeToVector(std::vector<int32_t>& vals);
- virtual int writeToVector(std::vector<int64_t>& vals);
- virtual int writeToVector(std::vector<unsigned char>& vals);
+ virtual int readfromVector(const ArrayProxy<float> vals);
+ virtual int readfromVector(const ArrayProxy<half_float::half> vals);
+ virtual int readfromVector(const ArrayProxy<int32_t> vals);
+ virtual int readfromVector(const ArrayProxy<int64_t> vals);
+ virtual int readfromVector(const ArrayProxy<unsigned char> vals);
+
+ virtual int writeToVector(ArrayProxy<float> vals);
+ virtual int writeToVector(ArrayProxy<half_float::half> vals);
+ virtual int writeToVector(ArrayProxy<int32_t> vals);
+ virtual int writeToVector(ArrayProxy<int64_t> vals);
+ virtual int writeToVector(ArrayProxy<unsigned char> vals);
const char* bool_to_str(bool in) const
{