aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Ward <james.ward@arm.com>2022-11-28 16:45:36 +0000
committerEric Kunze <eric.kunze@arm.com>2022-12-15 16:43:17 +0000
commit3d3d45d669a460c6bc8e51b9dd9a8149c51e3d7f (patch)
tree8b033d86b684e79cf756c4fb3b193dc83875e378
parent7f1ea8e05537354ab2d78744616140db22ce17dc (diff)
downloadreference_model-3d3d45d669a460c6bc8e51b9dd9a8149c51e3d7f.tar.gz
Add BF16 support to IModelRunner
Signed-off-by: James Ward <james.ward@arm.com> Change-Id: I3339a78d9611905583272ffad0ef7668e046cfad
-rw-r--r--reference_model/src/tensor.cc37
1 files changed, 37 insertions, 0 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 3cf4aa0..0678bbd 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -462,6 +462,24 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals)
setTensorValueFloat(elements, vals.data());
break;
+ case DType_BF16:
+ if (vals.size() != elements)
+ {
+ WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ for (auto v : vals)
+ {
+ ASSERT_MSG(
+ checkValidBFloat(v),
+ "Input float value not a valid bfloat16 value."
+ );
+ }
+
+ setTensorValueFloat(elements, vals.data());
+ break;
default:
WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).",
EnumNameDType(getDtype()));
@@ -598,6 +616,25 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<float> vals)
getTensorValueFloat(elements, vals.data());
break;
+ case DType_BF16:
+ if (vals.size() != elements)
+ {
+ WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ getTensorValueFloat(elements, vals.data());
+
+ for (auto v : vals)
+ {
+ ASSERT_MSG(
+ checkValidBFloat(v),
+ "Float value not a valid bfloat16 value."
+ );
+ }
+
+ break;
default:
WARNING("The output type (float) doesn't match the data type assigned to the tensor (%s).",
EnumNameDType(getDtype()));