diff options
-rw-r--r-- | reference_model/src/tensor.cc | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 3cf4aa0..0678bbd 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -462,6 +462,24 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals) setTensorValueFloat(elements, vals.data()); break; + case DType_BF16: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + for (auto v : vals) + { + ASSERT_MSG( + checkValidBFloat(v), + "Input float value not a valid bfloat16 value." + ); + } + + setTensorValueFloat(elements, vals.data()); + break; default: WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).", EnumNameDType(getDtype())); @@ -598,6 +616,25 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<float> vals) getTensorValueFloat(elements, vals.data()); break; + case DType_BF16: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueFloat(elements, vals.data()); + + for (auto v : vals) + { + ASSERT_MSG( + checkValidBFloat(v), + "Float value not a valid bfloat16 value." + ); + } + + break; default: WARNING("The output type (float) doesn't match the data type assigned to the tensor (%s).", EnumNameDType(getDtype())); |