aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/tensor.cc
diff options
context:
space:
mode:
authorJames Ward <james.ward@arm.com>2022-10-19 12:20:31 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2022-11-09 12:19:51 +0000
commit24dbc420aae556649f50e645bd94489dab2cc75a (patch)
tree490345da43e9c5bae0f450ba05ffe85874077e0a /reference_model/src/tensor.cc
parent3b0544c1e7463295c49a48a162ebb9a546326829 (diff)
downloadreference_model-24dbc420aae556649f50e645bd94489dab2cc75a.tar.gz
Add BF16 support to reference model
* Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work- arounds for reduce.any() and reduce.all() bugs (introduced between 3.3.7 and 3.4.0) * Truncation to bfloat16 now performed in eval() methods Signed-off-by: James Ward <james.ward@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r--reference_model/src/tensor.cc36
1 files changed, 29 insertions, 7 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 8d192ca..4eaf21d 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -90,10 +90,12 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
int64_t* i64databuf = nullptr;
bool* bdatabuf = nullptr;
NumpyUtilities::NPError nperror;
+ DType dtype = getDtype();
- switch (getDtype())
+ switch (dtype)
{
case DType_FP32:
+ case DType_BF16:
fdatabuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(fdatabuf);
@@ -154,19 +156,38 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
FATAL_ERROR("Unknown error parsing Numpy file: %s", filename);
}
- switch (getDtype())
+ switch (dtype)
{
case DType_FP16:
// Convert from fp16 to fp32
+ //TODO(jw): remove this once we cast to fp16 in register_fcn/eval
for (uint32_t i=0; i < elements; i++) {
fdatabuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]);
}
- // Fall through to DType_FP32 case
+ if (setTensorValueFloat(elements, fdatabuf))
+ {
+ free(f16databuf);
+ free(fdatabuf);
+ return 1;
+ }
+ break;
+ case DType_BF16:
+ for (uint32_t i=0; i < elements; i++)
+ {
+ ASSERT_MSG(
+ checkValidBFloat(fdatabuf[i]),
+ "Input float value not a valid bfloat16 value."
+ );
+ }
+ if (setTensorValueFloat(elements, fdatabuf))
+ {
+ free(fdatabuf);
+ return 1;
+ }
+ break;
case DType_FP32:
if (setTensorValueFloat(elements, fdatabuf))
{
- if (f16databuf)
- free(f16databuf);
free(fdatabuf);
return 1;
}
@@ -226,10 +247,12 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
bool* bdatabuf = nullptr;
NumpyUtilities::NPError nperror;
uint32_t elements = getElementCount();
+ DType dtype = getDtype();
- switch (getDtype())
+ switch (dtype)
{
case DType_FP32:
+ case DType_BF16:
fdatabuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(fdatabuf);
@@ -238,7 +261,6 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
free(fdatabuf);
return 1;
}
-
nperror = NumpyUtilities::writeToNpyFile(filename, shape, fdatabuf);
free(fdatabuf);