about | summary | refs | log | tree | commit | diff
path: root/reference_model/include
diff options
context:
space:
mode:
author: Won Jeon <won.jeon@arm.com> 2024-02-06 18:37:00 +0000
committer: Won Jeon <won.jeon@arm.com> 2024-02-21 19:38:55 +0000
commit: 2c34b4616a10539211e7006bc43f3c71e86c30bb (patch)
tree: aa4043a610ecd4c6d35b876cfb013dbe7dd0ab01 /reference_model/include
parent: 587cc84c2b8c4b0d030b5e257c9a32461c0969b9 (diff)
download: reference_model-2c34b4616a10539211e7006bc43f3c71e86c30bb.tar.gz
Add support for FP8 to reference model
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08
Diffstat (limited to 'reference_model/include')
-rw-r--r-- reference_model/include/dtype.h 16
-rw-r--r-- reference_model/include/types.h 30
2 files changed, 30 insertions, 16 deletions
diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h
index 1b01a0e..3e8bdf5 100644
--- a/reference_model/include/dtype.h
+++ b/reference_model/include/dtype.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -41,6 +41,8 @@ enum TOSA_REF_TYPE : uint32_t
TOSA_REF_TYPE_FP16 = 10,
TOSA_REF_TYPE_BF16 = 11,
TOSA_REF_TYPE_SHAPE = 12,
+ TOSA_REF_TYPE_FP8E4M3 = 13,
+ TOSA_REF_TYPE_FP8E5M2 = 14,
TOSA_REF_TYPE_FP64 = 99, // FP64 is special: add new data types above
};
@@ -74,6 +76,10 @@ inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e)
return EnumNameDType(DType_BF16);
case TOSA_REF_TYPE_SHAPE:
return EnumNameDType(DType_SHAPE);
+ case TOSA_REF_TYPE_FP8E4M3:
+ return EnumNameDType(DType_FP8E4M3);
+ case TOSA_REF_TYPE_FP8E5M2:
+ return EnumNameDType(DType_FP8E5M2);
case TOSA_REF_TYPE_FP64:
return "FP64";
default:
@@ -85,7 +91,7 @@ inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e)
// return corresponding TOSA_REF_TYPE for DType
inline TOSA_REF_TYPE ConvertDType(const DType dtype)
{
- assert(DType_MAX == DType_SHAPE); // must update whenever DType_MAX changes
+ assert(DType_MAX == DType_FP8E5M2); // must update whenever DType_MAX changes
if (g_func_config.precise_mode)
{
@@ -95,6 +101,8 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype)
case DType_FP16:
case DType_FP32:
case DType_BF16:
+ case DType_FP8E4M3:
+ case DType_FP8E5M2:
return TOSA_REF_TYPE_FP64;
default:
break;
@@ -127,6 +135,10 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype)
return TOSA_REF_TYPE_BF16;
case DType_SHAPE:
return TOSA_REF_TYPE_SHAPE;
+ case DType_FP8E4M3:
+ return TOSA_REF_TYPE_FP8E4M3;
+ case DType_FP8E5M2:
+ return TOSA_REF_TYPE_FP8E5M2;
default:
break;
}
diff --git a/reference_model/include/types.h b/reference_model/include/types.h
index 15ee40c..32a8ce1 100644
--- a/reference_model/include/types.h
+++ b/reference_model/include/types.h
@@ -26,19 +26,21 @@ extern "C"
enum tosa_datatype_t
{
- tosa_datatype_bf16_t = 0,
- tosa_datatype_bool_t = 1,
- tosa_datatype_fp16_t = 2,
- tosa_datatype_fp32_t = 3,
- tosa_datatype_int16_t = 4,
- tosa_datatype_int32_t = 5,
- tosa_datatype_int48_t = 6,
- tosa_datatype_int4_t = 7,
- tosa_datatype_int8_t = 8,
- tosa_datatype_uint16_t = 9,
- tosa_datatype_uint8_t = 10,
- tosa_datatype_shape_t = 11,
- tosa_datatype_fp64_t = 99
+ tosa_datatype_bf16_t = 0,
+ tosa_datatype_bool_t = 1,
+ tosa_datatype_fp16_t = 2,
+ tosa_datatype_fp32_t = 3,
+ tosa_datatype_int16_t = 4,
+ tosa_datatype_int32_t = 5,
+ tosa_datatype_int48_t = 6,
+ tosa_datatype_int4_t = 7,
+ tosa_datatype_int8_t = 8,
+ tosa_datatype_uint16_t = 9,
+ tosa_datatype_uint8_t = 10,
+ tosa_datatype_shape_t = 11,
+ tosa_datatype_fp8e4m3_t = 12,
+ tosa_datatype_fp8e5m2_t = 13,
+ tosa_datatype_fp64_t = 99
};
struct tosa_tensor_t
@@ -61,4 +63,4 @@ extern "C"
}
#endif /* __cplusplus */
-#endif // TYPES_H_
\ No newline at end of file
+#endif // TYPES_H_