From 2c34b4616a10539211e7006bc43f3c71e86c30bb Mon Sep 17 00:00:00 2001 From: Won Jeon Date: Tue, 6 Feb 2024 18:37:00 +0000 Subject: Add support for FP8 to reference model Signed-off-by: Won Jeon Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08 --- reference_model/include/dtype.h | 16 ++++++++++++++-- reference_model/include/types.h | 30 ++++++++++++++++-------------- 2 files changed, 30 insertions(+), 16 deletions(-) (limited to 'reference_model/include') diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h index 1b01a0e..3e8bdf5 100644 --- a/reference_model/include/dtype.h +++ b/reference_model/include/dtype.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -41,6 +41,8 @@ enum TOSA_REF_TYPE : uint32_t TOSA_REF_TYPE_FP16 = 10, TOSA_REF_TYPE_BF16 = 11, TOSA_REF_TYPE_SHAPE = 12, + TOSA_REF_TYPE_FP8E4M3 = 13, + TOSA_REF_TYPE_FP8E5M2 = 14, TOSA_REF_TYPE_FP64 = 99, // FP64 is special: add new data types above }; @@ -74,6 +76,10 @@ inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e) return EnumNameDType(DType_BF16); case TOSA_REF_TYPE_SHAPE: return EnumNameDType(DType_SHAPE); + case TOSA_REF_TYPE_FP8E4M3: + return EnumNameDType(DType_FP8E4M3); + case TOSA_REF_TYPE_FP8E5M2: + return EnumNameDType(DType_FP8E5M2); case TOSA_REF_TYPE_FP64: return "FP64"; default: @@ -85,7 +91,7 @@ inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e) // return corresponding TOSA_REF_TYPE for DType inline TOSA_REF_TYPE ConvertDType(const DType dtype) { - assert(DType_MAX == DType_SHAPE); // must update whenever DType_MAX changes + assert(DType_MAX == DType_FP8E5M2); // must update whenever DType_MAX changes if (g_func_config.precise_mode) { @@ -95,6 +101,8 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype) case DType_FP16: case DType_FP32: case DType_BF16: + case DType_FP8E4M3: + case DType_FP8E5M2: return TOSA_REF_TYPE_FP64; default: break; @@ -127,6 +135,10 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype) return TOSA_REF_TYPE_BF16; case DType_SHAPE: return TOSA_REF_TYPE_SHAPE; + case DType_FP8E4M3: + return TOSA_REF_TYPE_FP8E4M3; + case DType_FP8E5M2: + return TOSA_REF_TYPE_FP8E5M2; default: break; } diff --git a/reference_model/include/types.h b/reference_model/include/types.h index 15ee40c..32a8ce1 100644 --- a/reference_model/include/types.h +++ b/reference_model/include/types.h @@ -26,19 +26,21 @@ extern "C" enum tosa_datatype_t { - tosa_datatype_bf16_t = 0, - tosa_datatype_bool_t = 1, - tosa_datatype_fp16_t = 2, - tosa_datatype_fp32_t = 3, - tosa_datatype_int16_t = 4, - tosa_datatype_int32_t = 5, - tosa_datatype_int48_t = 6, - tosa_datatype_int4_t = 7, - tosa_datatype_int8_t = 8, - tosa_datatype_uint16_t = 9, - tosa_datatype_uint8_t = 10, - tosa_datatype_shape_t = 11, - tosa_datatype_fp64_t = 99 + tosa_datatype_bf16_t = 0, + tosa_datatype_bool_t = 1, + tosa_datatype_fp16_t = 2, + tosa_datatype_fp32_t = 3, + tosa_datatype_int16_t = 4, + tosa_datatype_int32_t = 5, + tosa_datatype_int48_t = 6, + tosa_datatype_int4_t = 7, + tosa_datatype_int8_t = 8, + tosa_datatype_uint16_t = 9, + tosa_datatype_uint8_t = 10, + tosa_datatype_shape_t = 11, + tosa_datatype_fp8e4m3_t = 12, + tosa_datatype_fp8e5m2_t = 13, + tosa_datatype_fp64_t = 99 }; struct tosa_tensor_t @@ -61,4 +63,4 @@ extern "C" } #endif /* __cplusplus */ -#endif // TYPES_H_ \ No newline at end of file +#endif // TYPES_H_ -- cgit v1.2.1