aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWon Jeon <won.jeon@arm.com>2024-02-27 17:52:45 +0000
committerWon Jeon <won.jeon@arm.com>2024-02-29 05:15:59 +0000
commit3195a665e3f96809a67b4cb04a57330d2bfeb0de (patch)
tree6ebcb1b8dbbfc9b992ac2c88095b892c8356f1e6
parent687805380e3b5aafb48ee1eb2b29857cc869f880 (diff)
downloadreference_model-3195a665e3f96809a67b4cb04a57330d2bfeb0de.tar.gz
Fix padding value for PAD op and tensor writing to npy for FP8
Signed-off-by: Won Jeon <won.jeon@arm.com> Change-Id: I55f663c19a1d2579d24b25c7f0d476e56e7e6dd2
-rw-r--r--reference_model/src/ops/data_layout.cc4
-rw-r--r--reference_model/src/tensor.cc6
2 files changed, 5 insertions, 5 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 3bff047..b6ad704 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -183,10 +183,12 @@ int OpPad<Rank, Dtype>::eval()
case TOSA_REF_TYPE_BF16:
case TOSA_REF_TYPE_FP32:
case TOSA_REF_TYPE_FP64:
+ case TOSA_REF_TYPE_FP8E4M3:
+ case TOSA_REF_TYPE_FP8E5M2:
pad_value = (InEigenType)attribute->pad_const_fp();
break;
default:
- printNodeValidationError("Unsupported data type");
+ ASSERT_MSG(false, "TOSA_REF_TYPE %s is not supported.", EnumNameTOSAREFTYPE(Dtype));
break;
}
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 16020cf..1417fed 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -289,12 +289,10 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
break;
case DType_FP8E4M3:
case DType_FP8E5M2:
- // FP8E4M3 -> FP64
f64databuf = (double*)calloc(sizeof(double), elements);
ASSERT_MEM(f64databuf);
for (uint32_t i = 0; i < elements; i++)
{
- //ASSERT_MSG(checkValidFloat8(f32databuf[i]), "Input float value not a valid float8 value.");
f64databuf[i] = static_cast<double>(f32databuf[i]);
}
if (setTensorValueDouble(elements, f64databuf))
@@ -366,6 +364,8 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
{
case TOSA_REF_TYPE_FP32:
case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP8E4M3:
+ case TOSA_REF_TYPE_FP8E5M2:
f32databuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(f32databuf);
@@ -379,8 +379,6 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
free(f32databuf);
break;
case TOSA_REF_TYPE_FP16:
- case TOSA_REF_TYPE_FP8E4M3:
- case TOSA_REF_TYPE_FP8E5M2:
f32databuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(f32databuf);
f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements);