From 3d3d45d669a460c6bc8e51b9dd9a8149c51e3d7f Mon Sep 17 00:00:00 2001 From: James Ward Date: Mon, 28 Nov 2022 16:45:36 +0000 Subject: Add BF16 support to IModelRunner Signed-off-by: James Ward Change-Id: I3339a78d9611905583272ffad0ef7668e046cfad --- reference_model/src/tensor.cc | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 3cf4aa0..0678bbd 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -460,6 +460,24 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy vals) return -1; } + setTensorValueFloat(elements, vals.data()); + break; + case DType_BF16: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + for (auto v : vals) + { + ASSERT_MSG( + checkValidBFloat(v), + "Input float value not a valid bfloat16 value." + ); + } + setTensorValueFloat(elements, vals.data()); break; default: @@ -597,6 +615,25 @@ int TosaReference::Tensor::writeToVector(ArrayProxy vals) } getTensorValueFloat(elements, vals.data()); + break; + case DType_BF16: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueFloat(elements, vals.data()); + + for (auto v : vals) + { + ASSERT_MSG( + checkValidBFloat(v), + "Float value not a valid bfloat16 value." + ); + } + break; default: WARNING("The output type (float) doesn't match the data type assigned to the tensor (%s).", -- cgit v1.2.1