aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/tensor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r--reference_model/src/tensor.cc21
1 files changed, 11 insertions, 10 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index e9598c4..3cf4aa0 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -15,6 +15,7 @@
#include "tensor.h"
#include "arith_util.h"
+#include "array_proxy.h"
#include "half.hpp"
using namespace TosaReference;
@@ -445,7 +446,7 @@ DEF_CTENSOR_COPY_VALUE_FROM(6, bool)
#undef DEF_CTENSOR_COPY_VALUE_FROM
-int TosaReference::Tensor::readfromVector(const std::vector<float>& vals)
+int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals)
{
uint32_t elements = getElementCount();
switch (getDtype())
@@ -470,7 +471,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<float>& vals)
return 0;
}
-int TosaReference::Tensor::readfromVector(const std::vector<half_float::half>& vals)
+int TosaReference::Tensor::readfromVector(const ArrayProxy<half_float::half> vals)
{
uint32_t elements = getElementCount();
std::vector<float> tensor(elements);
@@ -502,7 +503,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<half_float::half>& v
return 0;
}
-int TosaReference::Tensor::readfromVector(const std::vector<int32_t>& vals)
+int TosaReference::Tensor::readfromVector(const ArrayProxy<int32_t> vals)
{
uint32_t elements = getElementCount();
switch (getDtype())
@@ -531,7 +532,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<int32_t>& vals)
return 0;
}
-int TosaReference::Tensor::readfromVector(const std::vector<int64_t>& vals)
+int TosaReference::Tensor::readfromVector(const ArrayProxy<int64_t> vals)
{
uint32_t elements = getElementCount();
switch (getDtype())
@@ -555,7 +556,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<int64_t>& vals)
return 0;
}
-int TosaReference::Tensor::readfromVector(const std::vector<unsigned char>& vals)
+int TosaReference::Tensor::readfromVector(const ArrayProxy<unsigned char> vals)
{
uint32_t elements = getElementCount();
@@ -580,7 +581,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<unsigned char>& vals
return 0;
}
-int TosaReference::Tensor::writeToVector(std::vector<float>& vals)
+int TosaReference::Tensor::writeToVector(ArrayProxy<float> vals)
{
uint32_t elements = getElementCount();
@@ -605,7 +606,7 @@ int TosaReference::Tensor::writeToVector(std::vector<float>& vals)
return 0;
}
-int TosaReference::Tensor::writeToVector(std::vector<half_float::half>& vals)
+int TosaReference::Tensor::writeToVector(ArrayProxy<half_float::half> vals)
{
uint32_t elements = getElementCount();
std::vector<float> tensor(elements);
@@ -636,7 +637,7 @@ int TosaReference::Tensor::writeToVector(std::vector<half_float::half>& vals)
return 0;
}
-int TosaReference::Tensor::writeToVector(std::vector<int32_t>& vals)
+int TosaReference::Tensor::writeToVector(ArrayProxy<int32_t> vals)
{
uint32_t elements = getElementCount();
@@ -665,7 +666,7 @@ int TosaReference::Tensor::writeToVector(std::vector<int32_t>& vals)
return 0;
}
-int TosaReference::Tensor::writeToVector(std::vector<int64_t>& vals)
+int TosaReference::Tensor::writeToVector(ArrayProxy<int64_t> vals)
{
uint32_t elements = getElementCount();
@@ -689,7 +690,7 @@ int TosaReference::Tensor::writeToVector(std::vector<int64_t>& vals)
return 0;
}
-int TosaReference::Tensor::writeToVector(std::vector<unsigned char>& vals)
+int TosaReference::Tensor::writeToVector(ArrayProxy<unsigned char> vals)
{
uint32_t elements = getElementCount();