From 3195a665e3f96809a67b4cb04a57330d2bfeb0de Mon Sep 17 00:00:00 2001 From: Won Jeon Date: Tue, 27 Feb 2024 17:52:45 +0000 Subject: Fix padding value for PAD op and tensor writing to npy for FP8 Signed-off-by: Won Jeon Change-Id: I55f663c19a1d2579d24b25c7f0d476e56e7e6dd2 --- reference_model/src/ops/data_layout.cc | 4 +++- reference_model/src/tensor.cc | 6 ++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 3bff047..b6ad704 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -183,10 +183,12 @@ int OpPad::eval() case TOSA_REF_TYPE_BF16: case TOSA_REF_TYPE_FP32: case TOSA_REF_TYPE_FP64: + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: pad_value = (InEigenType)attribute->pad_const_fp(); break; default: - printNodeValidationError("Unsupported data type"); + ASSERT_MSG(false, "TOSA_REF_TYPE %s is not supported.", EnumNameTOSAREFTYPE(Dtype)); break; } diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 16020cf..1417fed 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -289,12 +289,10 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) break; case DType_FP8E4M3: case DType_FP8E5M2: - // FP8E4M3 -> FP64 f64databuf = (double*)calloc(sizeof(double), elements); ASSERT_MEM(f64databuf); for (uint32_t i = 0; i < elements; i++) { - //ASSERT_MSG(checkValidFloat8(f32databuf[i]), "Input float value not a valid float8 value."); f64databuf[i] = static_cast(f32databuf[i]); } if (setTensorValueDouble(elements, f64databuf)) @@ -366,6 +364,8 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const { case TOSA_REF_TYPE_FP32: case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: f32databuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(f32databuf); @@ -379,8 +379,6 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const free(f32databuf); break; case TOSA_REF_TYPE_FP16: - case TOSA_REF_TYPE_FP8E4M3: - case TOSA_REF_TYPE_FP8E5M2: f32databuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(f32databuf); f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements); -- cgit v1.2.1