diff options
author | James Ward <james.ward@arm.com> | 2022-11-28 16:45:36 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2022-12-15 16:43:17 +0000 |
commit | 3d3d45d669a460c6bc8e51b9dd9a8149c51e3d7f (patch) | |
tree | 8b033d86b684e79cf756c4fb3b193dc83875e378 | |
parent | 7f1ea8e05537354ab2d78744616140db22ce17dc (diff) | |
download | reference_model-3d3d45d669a460c6bc8e51b9dd9a8149c51e3d7f.tar.gz |
Add BF16 support to IModelRunner
Signed-off-by: James Ward <james.ward@arm.com>
Change-Id: I3339a78d9611905583272ffad0ef7668e046cfad
-rw-r--r-- | reference_model/src/tensor.cc | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 3cf4aa0..0678bbd 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -462,6 +462,24 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals) setTensorValueFloat(elements, vals.data()); break; + case DType_BF16: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + for (auto v : vals) + { + ASSERT_MSG( + checkValidBFloat(v), + "Input float value not a valid bfloat16 value." + ); + } + + setTensorValueFloat(elements, vals.data()); + break; default: WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).", EnumNameDType(getDtype())); @@ -598,6 +616,25 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<float> vals) getTensorValueFloat(elements, vals.data()); break; + case DType_BF16: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueFloat(elements, vals.data()); + + for (auto v : vals) + { + ASSERT_MSG( + checkValidBFloat(v), + "Float value not a valid bfloat16 value." + ); + } + + break; default: WARNING("The output type (float) doesn't match the data type assigned to the tensor (%s).", EnumNameDType(getDtype())); |