aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/tensor.cc
diff options
context:
space:
mode:
author: Matthew Sloyan <matthew.sloyan@arm.com> 2022-10-18 18:02:48 +0100
committer: Matthew Sloyan <matthew.sloyan@arm.com> 2022-10-19 17:41:40 +0100
commit: 2e4d889fb036d1c0a34503400a3f45cfc6f9f3e1 (patch)
tree: 083a125f75dc4e338bf9382da538742b4cd5bfa7 /reference_model/src/tensor.cc
parent: 4196491afc23d375b5476b05be16defeed4eadad (diff)
download: reference_model-2e4d889fb036d1c0a34503400a3f45cfc6f9f3e1.tar.gz
Add FP16 support to IModelRunner
* Added specific FP16 readfromVector and writeToVector methods.
* Added FP16 support to float readfromVector and writeToVector methods.
* Added missing reference to IModelRunner::setInput.

Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com>
Change-Id: I6b66468737e672afc925ccad4fb710fbb9427c14
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r-- reference_model/src/tensor.cc 65
1 files changed, 65 insertions, 0 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 8e65a27..8d192ca 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -429,6 +429,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<float>& vals)
uint32_t elements = getElementCount();
switch (getDtype())
{
+ case DType_FP16:
case DType_FP32:
if (vals.size() != elements)
{
@@ -448,6 +449,38 @@ int TosaReference::Tensor::readfromVector(const std::vector<float>& vals)
return 0;
}
+/// Loads this tensor's contents from a vector of half-precision values.
+///
+/// Only a tensor whose dtype is DType_FP16 is accepted. Every element of
+/// `vals` is widened to fp32 with half_float::half_cast and the resulting
+/// buffer is passed to setTensorValueFloat(); on success the tensor is then
+/// flagged through setIsValid().
+///
+/// @param vals  Source values; vals.size() must equal getElementCount().
+/// @return  0 on success,
+///         -1 if vals.size() differs from the tensor's element count,
+///         -2 if the tensor's dtype is not DType_FP16.
+int TosaReference::Tensor::readfromVector(const std::vector<half_float::half>& vals)
+{
+    uint32_t elements = getElementCount();
+
+    switch (getDtype())
+    {
+        case DType_FP16:
+        {
+            if (vals.size() != elements)
+            {
+                // %zu / %u match the size_t and uint32_t arguments; the previous
+                // %ld / %d specifiers did not match the argument types.
+                WARNING("The input size (%zu) doesn't match the number of elements (%u) assigned to the tensor.",
+                        vals.size(), elements);
+                return -1;
+            }
+
+            // The fp32 staging buffer is allocated only after validation so the
+            // error paths do not pay for an allocation they never use.
+            std::vector<float> tensor(elements);
+
+            // Convert from fp16 to fp32
+            for (uint32_t i = 0; i < elements; i++)
+            {
+                tensor[i] = half_float::half_cast<float, half_float::half>(vals[i]);
+            }
+
+            // NOTE(review): if setTensorValueFloat reports failure through a
+            // return value, it is not inspected here -- confirm against its declaration.
+            setTensorValueFloat(elements, tensor.data());
+            break;
+        }
+        default:
+            WARNING("The input type doesn't match the data type assigned to the tensor (%s).",
+                    EnumNameDType(getDtype()));
+            return -2;
+    }
+    setIsValid();
+    return 0;
+}
+
int TosaReference::Tensor::readfromVector(const std::vector<int32_t>& vals)
{
uint32_t elements = getElementCount();
@@ -532,6 +565,7 @@ int TosaReference::Tensor::writeToVector(std::vector<float>& vals)
switch (getDtype())
{
+ case DType_FP16:
case DType_FP32:
if (vals.size() != elements)
{
@@ -550,6 +584,37 @@ int TosaReference::Tensor::writeToVector(std::vector<float>& vals)
return 0;
}
+/// Copies this tensor's contents into a vector of half-precision values.
+///
+/// Only a tensor whose dtype is DType_FP16 is accepted. The values are first
+/// fetched as fp32 through getTensorValueFloat() and each one is then narrowed
+/// to fp16 with half_float::half_cast before being stored in `vals`.
+///
+/// @param vals  Destination; it is not resized here, so vals.size() must
+///              already equal getElementCount().
+/// @return  0 on success,
+///         -1 if vals.size() differs from the tensor's element count,
+///         -2 if the tensor's dtype is not DType_FP16.
+int TosaReference::Tensor::writeToVector(std::vector<half_float::half>& vals)
+{
+    uint32_t elements = getElementCount();
+
+    switch (getDtype())
+    {
+        case DType_FP16:
+        {
+            if (vals.size() != elements)
+            {
+                // %zu / %u match the size_t and uint32_t arguments; the previous
+                // %ld / %d specifiers did not match the argument types.
+                WARNING("The output size (%zu) doesn't match the number of elements (%u) assigned to the tensor.",
+                        vals.size(), elements);
+                return -1;
+            }
+
+            // The fp32 staging buffer is allocated only after validation so the
+            // error paths do not pay for an allocation they never use.
+            std::vector<float> tensor(elements);
+            getTensorValueFloat(elements, tensor.data());
+
+            // Convert fp32 to fp16
+            for (uint32_t i = 0; i < elements; i++)
+            {
+                vals[i] = half_float::half_cast<half_float::half, float>(tensor[i]);
+            }
+            break;
+        }
+        default:
+            WARNING("The output type doesn't match the data type assigned to the tensor (%s).",
+                    EnumNameDType(getDtype()));
+            return -2;
+    }
+    return 0;
+}
+
int TosaReference::Tensor::writeToVector(std::vector<int32_t>& vals)
{
uint32_t elements = getElementCount();