aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/tensor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r--reference_model/src/tensor.cc36
1 files changed, 29 insertions, 7 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 8d192ca..4eaf21d 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -90,10 +90,12 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
int64_t* i64databuf = nullptr;
bool* bdatabuf = nullptr;
NumpyUtilities::NPError nperror;
+ DType dtype = getDtype();
- switch (getDtype())
+ switch (dtype)
{
case DType_FP32:
+ case DType_BF16:
fdatabuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(fdatabuf);
@@ -154,19 +156,38 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
FATAL_ERROR("Unknown error parsing Numpy file: %s", filename);
}
- switch (getDtype())
+ switch (dtype)
{
case DType_FP16:
// Convert from fp16 to fp32
+ //TODO(jw): remove this once we cast to fp16 in register_fcn/eval
for (uint32_t i=0; i < elements; i++) {
fdatabuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]);
}
- // Fall through to DType_FP32 case
+ if (setTensorValueFloat(elements, fdatabuf))
+ {
+ free(f16databuf);
+ free(fdatabuf);
+ return 1;
+ }
+ break;
+ case DType_BF16:
+ for (uint32_t i=0; i < elements; i++)
+ {
+ ASSERT_MSG(
+ checkValidBFloat(fdatabuf[i]),
+ "Input float value not a valid bfloat16 value."
+ );
+ }
+ if (setTensorValueFloat(elements, fdatabuf))
+ {
+ free(fdatabuf);
+ return 1;
+ }
+ break;
case DType_FP32:
if (setTensorValueFloat(elements, fdatabuf))
{
- if (f16databuf)
- free(f16databuf);
free(fdatabuf);
return 1;
}
@@ -226,10 +247,12 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
bool* bdatabuf = nullptr;
NumpyUtilities::NPError nperror;
uint32_t elements = getElementCount();
+ DType dtype = getDtype();
- switch (getDtype())
+ switch (dtype)
{
case DType_FP32:
+ case DType_BF16:
fdatabuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(fdatabuf);
@@ -238,7 +261,6 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
free(fdatabuf);
return 1;
}
-
nperror = NumpyUtilities::writeToNpyFile(filename, shape, fdatabuf);
free(fdatabuf);