diff options
author | James Ward <james.ward@arm.com> | 2022-10-19 12:20:31 +0100 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-11-09 12:19:51 +0000 |
commit | 24dbc420aae556649f50e645bd94489dab2cc75a (patch) | |
tree | 490345da43e9c5bae0f450ba05ffe85874077e0a /reference_model/src/tensor.cc | |
parent | 3b0544c1e7463295c49a48a162ebb9a546326829 (diff) | |
download | reference_model-24dbc420aae556649f50e645bd94489dab2cc75a.tar.gz |
Add BF16 support to reference model
* Upgrade Eigen to 3.4.0 (for bfloat16 support) and add
workarounds for reduce.any() and reduce.all() bugs (introduced
between 3.3.7 and 3.4.0)
* Truncation to bfloat16 now performed in eval() methods
Signed-off-by: James Ward <james.ward@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r-- | reference_model/src/tensor.cc | 36 |
1 files changed, 29 insertions, 7 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 8d192ca..4eaf21d 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -90,10 +90,12 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) int64_t* i64databuf = nullptr; bool* bdatabuf = nullptr; NumpyUtilities::NPError nperror; + DType dtype = getDtype(); - switch (getDtype()) + switch (dtype) { case DType_FP32: + case DType_BF16: fdatabuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(fdatabuf); @@ -154,19 +156,38 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) FATAL_ERROR("Unknown error parsing Numpy file: %s", filename); } - switch (getDtype()) + switch (dtype) { case DType_FP16: // Convert from fp16 to fp32 + //TODO(jw): remove this once we cast to fp16 in register_fcn/eval for (uint32_t i=0; i < elements; i++) { fdatabuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]); } - // Fall through to DType_FP32 case + if (setTensorValueFloat(elements, fdatabuf)) + { + free(f16databuf); + free(fdatabuf); + return 1; + } + break; + case DType_BF16: + for (uint32_t i=0; i < elements; i++) + { + ASSERT_MSG( + checkValidBFloat(fdatabuf[i]), + "Input float value not a valid bfloat16 value." 
+ ); + } + if (setTensorValueFloat(elements, fdatabuf)) + { + free(fdatabuf); + return 1; + } + break; case DType_FP32: if (setTensorValueFloat(elements, fdatabuf)) { - if (f16databuf) - free(f16databuf); free(fdatabuf); return 1; } @@ -226,10 +247,12 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const bool* bdatabuf = nullptr; NumpyUtilities::NPError nperror; uint32_t elements = getElementCount(); + DType dtype = getDtype(); - switch (getDtype()) + switch (dtype) { case DType_FP32: + case DType_BF16: fdatabuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(fdatabuf); @@ -238,7 +261,6 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const free(fdatabuf); return 1; } - nperror = NumpyUtilities::writeToNpyFile(filename, shape, fdatabuf); free(fdatabuf); |