about summary refs log tree commit diff
path: root/reference_model/src/tensor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r-- reference_model/src/tensor.cc 34
1 files changed, 33 insertions, 1 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index f9ec937..27f21f3 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -115,12 +115,15 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
assert(dtype == ConvertDType(serialization_dtype));
// if dtype is FP64, serialization_dtype must be one of FP32, FP16, BF16
assert(dtype != TOSA_REF_TYPE_FP64 || serialization_dtype == DType_FP32 || serialization_dtype == DType_FP16 ||
- serialization_dtype == DType_BF16);
+ serialization_dtype == DType_BF16 || serialization_dtype == DType_FP8E4M3 ||
+ serialization_dtype == DType_FP8E5M2);
switch (serialization_dtype)
{
case DType_FP32:
case DType_BF16:
+ case DType_FP8E4M3:
+ case DType_FP8E5M2:
f32databuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(f32databuf);
@@ -208,6 +211,14 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
return 1;
}
break;
+ case TOSA_REF_TYPE_FP8E4M3:
+ case TOSA_REF_TYPE_FP8E5M2:
+ if (setTensorValueFloat(elements, f32databuf))
+ {
+ free(f32databuf);
+ return 1;
+ }
+ break;
case TOSA_REF_TYPE_FP32:
if (setTensorValueFloat(elements, f32databuf))
{
@@ -276,6 +287,23 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
return 1;
}
break;
+ case DType_FP8E4M3:
+ case DType_FP8E5M2:
+ // FP8E4M3 / FP8E5M2 -> FP64 (values were loaded into f32databuf as floats)
+ f64databuf = (double*)calloc(sizeof(double), elements);
+ ASSERT_MEM(f64databuf);
+ for (uint32_t i = 0; i < elements; i++)
+ {
+ //ASSERT_MSG(checkValidFloat8(f32databuf[i]), "Input float value not a valid float8 value.");
+ f64databuf[i] = static_cast<double>(f32databuf[i]);
+ }
+ if (setTensorValueDouble(elements, f64databuf))
+ {
+ free(f32databuf);
+ free(f64databuf);
+ return 1;
+ }
+ break;
case DType_FP32:
// FP32 -> FP64
f64databuf = (double*)calloc(sizeof(double), elements);
@@ -349,6 +377,8 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
free(f32databuf);
break;
case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_FP8E4M3:
+ case TOSA_REF_TYPE_FP8E5M2:
f32databuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(f32databuf);
f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements);
@@ -631,6 +661,8 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals)
// continue with setting float vals in the tensor
case TOSA_REF_TYPE_FP16:
case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_FP8E4M3:
+ case TOSA_REF_TYPE_FP8E5M2:
if (vals.size() != elements)
{
WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",