diff options
author | Won Jeon <won.jeon@arm.com> | 2024-02-06 18:37:00 +0000 |
---|---|---|
committer | Won Jeon <won.jeon@arm.com> | 2024-02-21 19:38:55 +0000 |
commit | 2c34b4616a10539211e7006bc43f3c71e86c30bb (patch) | |
tree | aa4043a610ecd4c6d35b876cfb013dbe7dd0ab01 /reference_model/src/tensor.cc | |
parent | 587cc84c2b8c4b0d030b5e257c9a32461c0969b9 (diff) | |
download | reference_model-2c34b4616a10539211e7006bc43f3c71e86c30bb.tar.gz |
Add support for FP8 to reference model
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r-- | reference_model/src/tensor.cc | 34 |
1 files changed, 33 insertions, 1 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index f9ec937..27f21f3 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -115,12 +115,15 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) assert(dtype == ConvertDType(serialization_dtype)); // if dtype is FP64, serialization_dtype must be one of FP32, FP16, BF16 assert(dtype != TOSA_REF_TYPE_FP64 || serialization_dtype == DType_FP32 || serialization_dtype == DType_FP16 || - serialization_dtype == DType_BF16); + serialization_dtype == DType_BF16 || serialization_dtype == DType_FP8E4M3 || + serialization_dtype == DType_FP8E5M2); switch (serialization_dtype) { case DType_FP32: case DType_BF16: + case DType_FP8E4M3: + case DType_FP8E5M2: f32databuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(f32databuf); @@ -208,6 +211,14 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) return 1; } break; + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: + if (setTensorValueFloat(elements, f32databuf)) + { + free(f32databuf); + return 1; + } + break; case TOSA_REF_TYPE_FP32: if (setTensorValueFloat(elements, f32databuf)) { @@ -276,6 +287,23 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) return 1; } break; + case DType_FP8E4M3: + case DType_FP8E5M2: + // FP8E4M3 -> FP64 + f64databuf = (double*)calloc(sizeof(double), elements); + ASSERT_MEM(f64databuf); + for (uint32_t i = 0; i < elements; i++) + { + //ASSERT_MSG(checkValidFloat8(f32databuf[i]), "Input float value not a valid float8 value."); + f64databuf[i] = static_cast<double>(f32databuf[i]); + } + if (setTensorValueDouble(elements, f64databuf)) + { + free(f32databuf); + free(f64databuf); + return 1; + } + break; case DType_FP32: // FP32 -> FP64 f64databuf = (double*)calloc(sizeof(double), elements); @@ -349,6 +377,8 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const free(f32databuf); break; case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: f32databuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(f32databuf); f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements); @@ -631,6 +661,8 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals) // continue with setting float vals in the tensor case TOSA_REF_TYPE_FP16: case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: if (vals.size() != elements) { WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", |