aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/tensor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r--reference_model/src/tensor.cc6
1 files changed, 2 insertions, 4 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 16020cf..1417fed 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -289,12 +289,10 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
break;
case DType_FP8E4M3:
case DType_FP8E5M2:
- // FP8E4M3 -> FP64
f64databuf = (double*)calloc(sizeof(double), elements);
ASSERT_MEM(f64databuf);
for (uint32_t i = 0; i < elements; i++)
{
- //ASSERT_MSG(checkValidFloat8(f32databuf[i]), "Input float value not a valid float8 value.");
f64databuf[i] = static_cast<double>(f32databuf[i]);
}
if (setTensorValueDouble(elements, f64databuf))
@@ -366,6 +364,8 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
{
case TOSA_REF_TYPE_FP32:
case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP8E4M3:
+ case TOSA_REF_TYPE_FP8E5M2:
f32databuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(f32databuf);
@@ -379,8 +379,6 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
free(f32databuf);
break;
case TOSA_REF_TYPE_FP16:
- case TOSA_REF_TYPE_FP8E4M3:
- case TOSA_REF_TYPE_FP8E5M2:
f32databuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(f32databuf);
f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements);