aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/tensor.cc
diff options
context:
space:
mode:
authorMatthew Sloyan <matthew.sloyan@arm.com>2022-09-26 13:31:43 +0100
committerMatthew Sloyan <matthew.sloyan@arm.com>2022-10-07 16:02:07 +0100
commitba5fad356a926d5e1c6e0fe6b546a310230cc5a8 (patch)
tree10e3e127c091da90d591253fd55e8566c0b61e7a /reference_model/src/tensor.cc
parenta0848c6edbf37034e280a670bdd2f990fdf796da (diff)
downloadreference_model-ba5fad356a926d5e1c6e0fe6b546a310230cc5a8.tar.gz
Add IModelRunner interface to TOSA Reference Model
* Added IModelRunner interface using pimpl idiom, which allows a user to initialize, configure and run the model. * Added unit tests for IModelRunner. * Added doctest as third-party submodule. * Added user options to specify paths for dependencies. * Moved general func_config functions to separate utility, which removes cxxopts dependency. Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com> Change-Id: If42f1f82cd6dadf18911a48dcd5fa579b719aff2
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r--reference_model/src/tensor.cc203
1 files changed, 203 insertions, 0 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 90aee05..7cbeb13 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -382,6 +382,209 @@ DEF_CTENSOR_COPY_VALUE_FROM(6, bool)
#undef DEF_CTENSOR_COPY_VALUE_FROM
+int TosaReference::Tensor::readfromVector(const std::vector<float>& vals)
+{
+ uint32_t elements = getElementCount();
+ switch (getDtype())
+ {
+ case DType_FLOAT:
+ if (vals.size() != elements)
+ {
+ WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ setTensorValueFloat(elements, vals.data());
+ break;
+ default:
+ WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ setIsValid();
+ return 0;
+}
+
+int TosaReference::Tensor::readfromVector(const std::vector<int32_t>& vals)
+{
+ uint32_t elements = getElementCount();
+ switch (getDtype())
+ {
+ case DType_INT32:
+ case DType_UINT8:
+ case DType_INT4:
+ case DType_INT8:
+ case DType_INT16:
+ case DType_UINT16:
+ if (vals.size() != elements)
+ {
+ WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ setTensorValueInt32(elements, vals.data());
+ break;
+ default:
+ WARNING("The input type doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ setIsValid();
+ return 0;
+}
+
+int TosaReference::Tensor::readfromVector(const std::vector<int64_t>& vals)
+{
+ uint32_t elements = getElementCount();
+ switch (getDtype())
+ {
+ case DType_INT48:
+ if (vals.size() != elements)
+ {
+ WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ setTensorValueInt64(elements, vals.data());
+ break;
+ default:
+ WARNING("The input type doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ setIsValid();
+ return 0;
+}
+
+int TosaReference::Tensor::readfromVector(const std::vector<unsigned char>& vals)
+{
+ uint32_t elements = getElementCount();
+
+ switch (getDtype())
+ {
+ case DType_BOOL:
+ if (vals.size() != elements)
+ {
+ WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ setTensorValueBool(elements, reinterpret_cast<const bool*>(vals.data()));
+ break;
+ default:
+ WARNING("The input type (bool) doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ setIsValid();
+ return 0;
+}
+
+int TosaReference::Tensor::writeToVector(std::vector<float>& vals)
+{
+ uint32_t elements = getElementCount();
+
+ switch (getDtype())
+ {
+ case DType_FLOAT:
+ if (vals.size() != elements)
+ {
+ WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ getTensorValueFloat(elements, vals.data());
+ break;
+ default:
+ WARNING("The output type (float) doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ return 0;
+}
+
+int TosaReference::Tensor::writeToVector(std::vector<int32_t>& vals)
+{
+ uint32_t elements = getElementCount();
+
+ switch (getDtype())
+ {
+ case DType_INT32:
+ case DType_UINT8:
+ case DType_INT4:
+ case DType_INT8:
+ case DType_INT16:
+ case DType_UINT16:
+ if (vals.size() != elements)
+ {
+ WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ getTensorValueInt32(elements, vals.data());
+ break;
+ default:
+ WARNING("The output type doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ return 0;
+}
+
+int TosaReference::Tensor::writeToVector(std::vector<int64_t>& vals)
+{
+ uint32_t elements = getElementCount();
+
+ switch (getDtype())
+ {
+ case tosa::DType_INT48:
+ if (vals.size() != elements)
+ {
+ WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ getTensorValueInt64(elements, vals.data());
+ break;
+ default:
+ WARNING("The output type doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ return 0;
+}
+
+int TosaReference::Tensor::writeToVector(std::vector<unsigned char>& vals)
+{
+ uint32_t elements = getElementCount();
+
+ switch (getDtype())
+ {
+ case tosa::DType_BOOL:
+ if (vals.size() != elements)
+ {
+ WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ getTensorValueBool(elements, reinterpret_cast<bool*>(vals.data()));
+ break;
+ default:
+ WARNING("The output type (bool) doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ return 0;
+}
+
template <class T>
int TosaReference::TensorTemplate<T>::setTensorValueFloat(const size_t buflen, const float* vals)
{