diff options
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r-- | reference_model/src/tensor.cc | 36 |
1 files changed, 29 insertions, 7 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 8d192ca..4eaf21d 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -90,10 +90,12 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) int64_t* i64databuf = nullptr; bool* bdatabuf = nullptr; NumpyUtilities::NPError nperror; + DType dtype = getDtype(); - switch (getDtype()) + switch (dtype) { case DType_FP32: + case DType_BF16: fdatabuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(fdatabuf); @@ -154,19 +156,38 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) FATAL_ERROR("Unknown error parsing Numpy file: %s", filename); } - switch (getDtype()) + switch (dtype) { case DType_FP16: // Convert from fp16 to fp32 + //TODO(jw): remove this once we cast to fp16 in register_fcn/eval for (uint32_t i=0; i < elements; i++) { fdatabuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]); } - // Fall through to DType_FP32 case + if (setTensorValueFloat(elements, fdatabuf)) + { + free(f16databuf); + free(fdatabuf); + return 1; + } + break; + case DType_BF16: + for (uint32_t i=0; i < elements; i++) + { + ASSERT_MSG( + checkValidBFloat(fdatabuf[i]), + "Input float value not a valid bfloat16 value." + ); + } + if (setTensorValueFloat(elements, fdatabuf)) + { + free(fdatabuf); + return 1; + } + break; case DType_FP32: if (setTensorValueFloat(elements, fdatabuf)) { - if (f16databuf) - free(f16databuf); free(fdatabuf); return 1; } @@ -226,10 +247,12 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const bool* bdatabuf = nullptr; NumpyUtilities::NPError nperror; uint32_t elements = getElementCount(); + DType dtype = getDtype(); - switch (getDtype()) + switch (dtype) { case DType_FP32: + case DType_BF16: fdatabuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(fdatabuf); @@ -238,7 +261,6 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const free(fdatabuf); return 1; } - nperror = NumpyUtilities::writeToNpyFile(filename, shape, fdatabuf); free(fdatabuf); |