diff options
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r-- | reference_model/src/tensor.cc | 34 |
1 files changed, 33 insertions, 1 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index f9ec937..27f21f3 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -115,12 +115,15 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) assert(dtype == ConvertDType(serialization_dtype)); // if dtype is FP64, serialization_dtype must be one of FP32, FP16, BF16 assert(dtype != TOSA_REF_TYPE_FP64 || serialization_dtype == DType_FP32 || serialization_dtype == DType_FP16 || - serialization_dtype == DType_BF16); + serialization_dtype == DType_BF16 || serialization_dtype == DType_FP8E4M3 || + serialization_dtype == DType_FP8E5M2); switch (serialization_dtype) { case DType_FP32: case DType_BF16: + case DType_FP8E4M3: + case DType_FP8E5M2: f32databuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(f32databuf); @@ -208,6 +211,14 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) return 1; } break; + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: + if (setTensorValueFloat(elements, f32databuf)) + { + free(f32databuf); + return 1; + } + break; case TOSA_REF_TYPE_FP32: if (setTensorValueFloat(elements, f32databuf)) { @@ -276,6 +287,23 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) return 1; } break; + case DType_FP8E4M3: + case DType_FP8E5M2: + // FP8E4M3 -> FP64 + f64databuf = (double*)calloc(sizeof(double), elements); + ASSERT_MEM(f64databuf); + for (uint32_t i = 0; i < elements; i++) + { + //ASSERT_MSG(checkValidFloat8(f32databuf[i]), "Input float value not a valid float8 value."); + f64databuf[i] = static_cast<double>(f32databuf[i]); + } + if (setTensorValueDouble(elements, f64databuf)) + { + free(f32databuf); + free(f64databuf); + return 1; + } + break; case DType_FP32: // FP32 -> FP64 f64databuf = (double*)calloc(sizeof(double), elements); @@ -349,6 +377,8 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const free(f32databuf); break; case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: f32databuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(f32databuf); f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements); @@ -631,6 +661,8 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals) // continue with setting float vals in the tensor case TOSA_REF_TYPE_FP16: case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: if (vals.size() != elements) { WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", |