diff options
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r-- | reference_model/src/tensor.cc | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index e9598c4..3cf4aa0 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -15,6 +15,7 @@ #include "tensor.h" #include "arith_util.h" +#include "array_proxy.h" #include "half.hpp" using namespace TosaReference; @@ -445,7 +446,7 @@ DEF_CTENSOR_COPY_VALUE_FROM(6, bool) #undef DEF_CTENSOR_COPY_VALUE_FROM -int TosaReference::Tensor::readfromVector(const std::vector<float>& vals) +int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals) { uint32_t elements = getElementCount(); switch (getDtype()) @@ -470,7 +471,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<float>& vals) return 0; } -int TosaReference::Tensor::readfromVector(const std::vector<half_float::half>& vals) +int TosaReference::Tensor::readfromVector(const ArrayProxy<half_float::half> vals) { uint32_t elements = getElementCount(); std::vector<float> tensor(elements); @@ -502,7 +503,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<half_float::half>& v return 0; } -int TosaReference::Tensor::readfromVector(const std::vector<int32_t>& vals) +int TosaReference::Tensor::readfromVector(const ArrayProxy<int32_t> vals) { uint32_t elements = getElementCount(); switch (getDtype()) @@ -531,7 +532,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<int32_t>& vals) return 0; } -int TosaReference::Tensor::readfromVector(const std::vector<int64_t>& vals) +int TosaReference::Tensor::readfromVector(const ArrayProxy<int64_t> vals) { uint32_t elements = getElementCount(); switch (getDtype()) @@ -555,7 +556,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<int64_t>& vals) return 0; } -int TosaReference::Tensor::readfromVector(const std::vector<unsigned char>& vals) +int TosaReference::Tensor::readfromVector(const ArrayProxy<unsigned char> vals) { uint32_t elements = getElementCount(); @@ -580,7 +581,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<unsigned char>& vals return 0; } -int TosaReference::Tensor::writeToVector(std::vector<float>& vals) +int TosaReference::Tensor::writeToVector(ArrayProxy<float> vals) { uint32_t elements = getElementCount(); @@ -605,7 +606,7 @@ int TosaReference::Tensor::writeToVector(std::vector<float>& vals) return 0; } -int TosaReference::Tensor::writeToVector(std::vector<half_float::half>& vals) +int TosaReference::Tensor::writeToVector(ArrayProxy<half_float::half> vals) { uint32_t elements = getElementCount(); std::vector<float> tensor(elements); @@ -636,7 +637,7 @@ int TosaReference::Tensor::writeToVector(std::vector<half_float::half>& vals) return 0; } -int TosaReference::Tensor::writeToVector(std::vector<int32_t>& vals) +int TosaReference::Tensor::writeToVector(ArrayProxy<int32_t> vals) { uint32_t elements = getElementCount(); @@ -665,7 +666,7 @@ int TosaReference::Tensor::writeToVector(std::vector<int32_t>& vals) return 0; } -int TosaReference::Tensor::writeToVector(std::vector<int64_t>& vals) +int TosaReference::Tensor::writeToVector(ArrayProxy<int64_t> vals) { uint32_t elements = getElementCount(); @@ -689,7 +690,7 @@ int TosaReference::Tensor::writeToVector(std::vector<int64_t>& vals) return 0; } -int TosaReference::Tensor::writeToVector(std::vector<unsigned char>& vals) +int TosaReference::Tensor::writeToVector(ArrayProxy<unsigned char> vals) { uint32_t elements = getElementCount(); |