diff options
author | Won Jeon <won.jeon@arm.com> | 2024-02-06 18:37:00 +0000 |
---|---|---|
committer | Won Jeon <won.jeon@arm.com> | 2024-02-21 19:38:55 +0000 |
commit | 2c34b4616a10539211e7006bc43f3c71e86c30bb (patch) | |
tree | aa4043a610ecd4c6d35b876cfb013dbe7dd0ab01 /reference_model/include/dtype.h | |
parent | 587cc84c2b8c4b0d030b5e257c9a32461c0969b9 (diff) | |
download | reference_model-2c34b4616a10539211e7006bc43f3c71e86c30bb.tar.gz |
Add support for FP8 to reference model
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08
Diffstat (limited to 'reference_model/include/dtype.h')
-rw-r--r-- | reference_model/include/dtype.h | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h index 1b01a0e..3e8bdf5 100644 --- a/reference_model/include/dtype.h +++ b/reference_model/include/dtype.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -41,6 +41,8 @@ enum TOSA_REF_TYPE : uint32_t TOSA_REF_TYPE_FP16 = 10, TOSA_REF_TYPE_BF16 = 11, TOSA_REF_TYPE_SHAPE = 12, + TOSA_REF_TYPE_FP8E4M3 = 13, + TOSA_REF_TYPE_FP8E5M2 = 14, TOSA_REF_TYPE_FP64 = 99, // FP64 is special: add new data types above }; @@ -74,6 +76,10 @@ inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e) return EnumNameDType(DType_BF16); case TOSA_REF_TYPE_SHAPE: return EnumNameDType(DType_SHAPE); + case TOSA_REF_TYPE_FP8E4M3: + return EnumNameDType(DType_FP8E4M3); + case TOSA_REF_TYPE_FP8E5M2: + return EnumNameDType(DType_FP8E5M2); case TOSA_REF_TYPE_FP64: return "FP64"; default: @@ -85,7 +91,7 @@ inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e) // return corresponding TOSA_REF_TYPE for DType inline TOSA_REF_TYPE ConvertDType(const DType dtype) { - assert(DType_MAX == DType_SHAPE); // must update whenever DType_MAX changes + assert(DType_MAX == DType_FP8E5M2); // must update whenever DType_MAX changes if (g_func_config.precise_mode) { @@ -95,6 +101,8 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype) case DType_FP16: case DType_FP32: case DType_BF16: + case DType_FP8E4M3: + case DType_FP8E5M2: return TOSA_REF_TYPE_FP64; default: break; @@ -127,6 +135,10 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype) return TOSA_REF_TYPE_BF16; case DType_SHAPE: return TOSA_REF_TYPE_SHAPE; + case DType_FP8E4M3: + return TOSA_REF_TYPE_FP8E4M3; + case DType_FP8E5M2: + return TOSA_REF_TYPE_FP8E5M2; default: break; } |