aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py60
1 files changed, 55 insertions, 5 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 7ec0cfe..d0b9eb9 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -641,6 +641,8 @@ class TosaTensorValuesGen:
DType.FP32: (1 << 128) - (1 << (127 - 23)),
DType.FP16: (1 << 16) - (1 << (15 - 10)),
DType.BF16: (1 << 128) - (1 << (127 - 7)),
+ DType.FP8E4M3: 448,
+ DType.FP8E5M2: 57344,
}
# Default lowest normal values for random numbers
@@ -648,6 +650,8 @@ class TosaTensorValuesGen:
DType.FP32: np.exp2(-126),
DType.FP16: np.exp2(-14),
DType.BF16: np.exp2(-126),
+ DType.FP8E4M3: np.exp2(-9),
+ DType.FP8E5M2: np.exp2(-16),
}
@staticmethod
@@ -715,6 +719,8 @@ class TosaTensorValuesGen:
DType.FP16,
DType.FP32,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
):
# Change from inclusive to exclusive range
data_range = (data_range[0], data_range[1] + 1)
@@ -1734,7 +1740,13 @@ class TosaArgGen:
and "data_gen" in testGen.TOSA_OP_LIST[opName]
and gtu.dtypeIsSupportedByCompliance(dtype)
):
- if dtype in [DType.FP16, DType.FP32, DType.BF16]:
+ if dtype in [
+ DType.FP16,
+ DType.FP32,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]:
dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
else:
dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
@@ -2140,6 +2152,8 @@ class TosaArgGen:
accum_dtypes = [DType.FP32]
elif dtype == DType.FP32:
accum_dtypes = [DType.FP32]
+ elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
+ accum_dtypes = [DType.FP16]
elif error_name is None:
assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
@@ -2350,7 +2364,13 @@ class TosaArgGen:
if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
pad_const_int = testGen.getRandNumberDType(dtype)
pad_const_fp = 0
- elif dtype in (DType.FP16, DType.BF16, DType.FP32):
+ elif dtype in (
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ):
pad_const_int = 0
pad_const_fp = testGen.getRandNumberDType(dtype)
else:
@@ -2468,6 +2488,8 @@ class TosaArgGen:
accum_dtypes = [DType.FP16, DType.FP32]
elif dtype == DType.BF16 or dtype == DType.FP32:
accum_dtypes = [DType.FP32]
+ elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
+ accum_dtypes = [DType.FP16]
elif error_name is None:
assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
else:
@@ -2646,11 +2668,35 @@ class TosaArgGen:
elif inDtype == DType.BOOL:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP16:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ dtypeList = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
elif inDtype == DType.BF16:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ dtypeList = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
elif inDtype == DType.FP32:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
+ dtypeList = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
+ elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
+ dtypeList = [DType.FP16, DType.BF16, DType.FP32]
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output type for incorrect input type
dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
@@ -3232,6 +3278,10 @@ class TosaArgGen:
outputDTypeList = [DType.BF16]
elif dtype == DType.FP32:
outputDTypeList = [DType.FP32]
+ elif dtype == DType.FP8E4M3:
+ outputDTypeList = [DType.FP8E4M3]
+ elif dtype == DType.FP8E5M2:
+ outputDTypeList = [DType.FP8E5M2]
elif error_name == ErrorIf.WrongInputType:
# If an incorrect input type is used then we set a 'correct'
# output type to avoid other errors