diff options
author | Grant Watson <grant.watson@arm.com> | 2022-11-16 15:32:39 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2022-12-15 16:41:27 +0000 |
commit | 64285a1f25e2c7b85ed1f00b7947403e92baea00 (patch) | |
tree | 6d29c54f6497741449339e808508c854ba6a2267 /reference_model/src/tensor.cc | |
parent | b45db9a696f5df7b233f374248f329c16ee7ae64 (diff) | |
download | reference_model-64285a1f25e2c7b85ed1f00b7947403e92baea00.tar.gz |
Extend reference model API with eager operator execution entrypoints
- Adds a script to generate operators.h and operators.cc
- Adds jinja2 templates for generating operators.h and operators.cc
- Adds unit tests for a subset of the operators generated
- Includes the TOSA specification as a submodule
- Adds supporting C++ and header files
Signed-off-by: Grant Watson <grant.watson@arm.com>
Change-Id: I5b60db4c56113110d8e75fe1152525d258233f9c
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r-- | reference_model/src/tensor.cc | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index e9598c4..3cf4aa0 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -15,6 +15,7 @@ #include "tensor.h" #include "arith_util.h" +#include "array_proxy.h" #include "half.hpp" using namespace TosaReference; @@ -445,7 +446,7 @@ DEF_CTENSOR_COPY_VALUE_FROM(6, bool) #undef DEF_CTENSOR_COPY_VALUE_FROM -int TosaReference::Tensor::readfromVector(const std::vector<float>& vals) +int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals) { uint32_t elements = getElementCount(); switch (getDtype()) @@ -470,7 +471,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<float>& vals) return 0; } -int TosaReference::Tensor::readfromVector(const std::vector<half_float::half>& vals) +int TosaReference::Tensor::readfromVector(const ArrayProxy<half_float::half> vals) { uint32_t elements = getElementCount(); std::vector<float> tensor(elements); @@ -502,7 +503,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<half_float::half>& v return 0; } -int TosaReference::Tensor::readfromVector(const std::vector<int32_t>& vals) +int TosaReference::Tensor::readfromVector(const ArrayProxy<int32_t> vals) { uint32_t elements = getElementCount(); switch (getDtype()) @@ -531,7 +532,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<int32_t>& vals) return 0; } -int TosaReference::Tensor::readfromVector(const std::vector<int64_t>& vals) +int TosaReference::Tensor::readfromVector(const ArrayProxy<int64_t> vals) { uint32_t elements = getElementCount(); switch (getDtype()) @@ -555,7 +556,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<int64_t>& vals) return 0; } -int TosaReference::Tensor::readfromVector(const std::vector<unsigned char>& vals) +int TosaReference::Tensor::readfromVector(const ArrayProxy<unsigned char> vals) { uint32_t elements = getElementCount(); @@ -580,7 +581,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<unsigned char>& vals return 0; } -int TosaReference::Tensor::writeToVector(std::vector<float>& vals) +int TosaReference::Tensor::writeToVector(ArrayProxy<float> vals) { uint32_t elements = getElementCount(); @@ -605,7 +606,7 @@ int TosaReference::Tensor::writeToVector(std::vector<float>& vals) return 0; } -int TosaReference::Tensor::writeToVector(std::vector<half_float::half>& vals) +int TosaReference::Tensor::writeToVector(ArrayProxy<half_float::half> vals) { uint32_t elements = getElementCount(); std::vector<float> tensor(elements); @@ -636,7 +637,7 @@ int TosaReference::Tensor::writeToVector(std::vector<half_float::half>& vals) return 0; } -int TosaReference::Tensor::writeToVector(std::vector<int32_t>& vals) +int TosaReference::Tensor::writeToVector(ArrayProxy<int32_t> vals) { uint32_t elements = getElementCount(); @@ -665,7 +666,7 @@ int TosaReference::Tensor::writeToVector(std::vector<int32_t>& vals) return 0; } -int TosaReference::Tensor::writeToVector(std::vector<int64_t>& vals) +int TosaReference::Tensor::writeToVector(ArrayProxy<int64_t> vals) { uint32_t elements = getElementCount(); @@ -689,7 +690,7 @@ int TosaReference::Tensor::writeToVector(std::vector<int64_t>& vals) return 0; } -int TosaReference::Tensor::writeToVector(std::vector<unsigned char>& vals) +int TosaReference::Tensor::writeToVector(ArrayProxy<unsigned char> vals) { uint32_t elements = getElementCount(); |