author     Won Jeon <won.jeon@arm.com>    2024-02-06 18:37:00 +0000
committer  Won Jeon <won.jeon@arm.com>    2024-02-21 19:38:55 +0000
commit     2c34b4616a10539211e7006bc43f3c71e86c30bb (patch)
tree       aa4043a610ecd4c6d35b876cfb013dbe7dd0ab01 /verif
parent     587cc84c2b8c4b0d030b5e257c9a32461c0969b9 (diff)
download   reference_model-2c34b4616a10539211e7006bc43f3c71e86c30bb.tar.gz
Add support for FP8 to reference model
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08
Diffstat (limited to 'verif')
-rw-r--r--  verif/checker/tosa_result_checker.py                13
-rw-r--r--  verif/conformance/tosa_main_profile_ops_info.json  448
-rw-r--r--  verif/generator/tosa_arg_gen.py                     60
-rw-r--r--  verif/generator/tosa_error_if.py                    72
-rw-r--r--  verif/generator/tosa_test_gen.py                    95
-rw-r--r--  verif/generator/tosa_utils.py                       42
6 files changed, 702 insertions, 28 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 212c809..4d6d345 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -13,6 +13,7 @@ from checker.color_print import print_color
from checker.verifier import VerifierError
from checker.verifier import VerifierLibrary
from generator.tosa_utils import float32_is_valid_bfloat16
+from generator.tosa_utils import float32_is_valid_float8
from schemavalidation.schemavalidation import TestDescSchemaValidator
@@ -195,6 +196,18 @@ def test_check(
"reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}"
)
return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+ if "fp8e4m3" in misc_checks or "fp8e5m2" in misc_checks:
+ # Ensure floats are valid float8 values
+ test_res_is_fp8 = all([float32_is_valid_float8(f) for f in test_result.flat])
+ ref_res_is_fp8 = all(
+ [float32_is_valid_float8(f) for f in reference_result.flat]
+ )
+ if not (test_res_is_fp8 and ref_res_is_fp8):
+ msg = (
+ "All output values must be valid float8. "
+ "reference_result: {ref_res_is_float8}; test_result: {test_res_is_float8}"
+ )
+ return (TestResult.INCORRECT_FLOAT, 0.0, msg)
# for quantized test, allow +-(quantize_tolerance) error
if reference_result.dtype in (
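Note: as a rough sketch of how the new misc-check behaves, the branch above simply requires every output value to pass the float8 bit test from generator/tosa_utils.py (added later in this change). The helper names below are the real ones from this diff; the wrapper function is illustrative only.

    # Illustrative sketch: mirrors the new fp8 branch of test_check().
    from generator.tosa_utils import float32_is_valid_float8

    def results_are_valid_fp8(reference_result, test_result):
        # Every value must pass the "low 24 bits of the f32 encoding are zero" check.
        ref_ok = all(float32_is_valid_float8(f) for f in reference_result.flat)
        test_ok = all(float32_is_valid_float8(f) for f in test_result.flat)
        return ref_ok and test_ok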
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index 7792417..7559c62 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -185,6 +185,30 @@
"2"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -233,6 +257,24 @@
"--allow-pooling-and-conv-oversizes"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -315,6 +357,30 @@
"2,65538,1,1"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -527,6 +593,30 @@
"2"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -592,6 +682,30 @@
"1,2,1,65529"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -647,6 +761,24 @@
"--allow-pooling-and-conv-oversizes"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -722,6 +854,28 @@
"--allow-pooling-and-conv-oversizes"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--target-shape",
+ "1,7,18,5,4",
+ "--target-shape",
+ "1,6,12,17,3",
+ "--tensor-dim-range",
+ "1,4",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -787,6 +941,24 @@
"--allow-pooling-and-conv-oversizes"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -840,6 +1012,30 @@
"3"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -1183,6 +1379,30 @@
"5000,1,1"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -1505,6 +1725,30 @@
"1,65538,3"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -1551,6 +1795,24 @@
"--allow-pooling-and-conv-oversizes"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -1699,6 +1961,30 @@
"1,1,65539,1"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -1889,6 +2175,30 @@
"2"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -1935,6 +2245,30 @@
"1,65535,1,2"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -2046,6 +2380,24 @@
"2989,6,1"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -2091,6 +2443,30 @@
"1,65543,2,1"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -2161,6 +2537,30 @@
"1"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -2214,6 +2614,30 @@
"1"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-shape",
+ "10,24,9,13",
+ "--target-shape",
+ "8,14,20,5",
+ "--tensor-dim-range",
+ "1,16",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -3111,6 +3535,30 @@
"2"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 7ec0cfe..d0b9eb9 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -641,6 +641,8 @@ class TosaTensorValuesGen:
DType.FP32: (1 << 128) - (1 << (127 - 23)),
DType.FP16: (1 << 16) - (1 << (15 - 10)),
DType.BF16: (1 << 128) - (1 << (127 - 7)),
+ DType.FP8E4M3: 448,
+ DType.FP8E5M2: 57344,
}
# Default lowest normal values for random numbers
@@ -648,6 +650,8 @@ class TosaTensorValuesGen:
DType.FP32: np.exp2(-126),
DType.FP16: np.exp2(-14),
DType.BF16: np.exp2(-126),
+ DType.FP8E4M3: np.exp2(-9),
+ DType.FP8E5M2: np.exp2(-16),
}
@staticmethod
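Note: the new FP8 range constants match the OCP FP8 formats: E4M3's largest finite value is 1.75 * 2^8 = 448 and E5M2's is 1.75 * 2^15 = 57344, while 2^-9 and 2^-16 are the smallest positive magnitudes each format can represent. A quick sanity check, with the format parameters assumed from OCP FP8:

    # Sanity check of the FP8 range constants (OCP FP8 format parameters assumed).
    import numpy as np

    fp8e4m3_max = (2 - 2**-3) * 2**8    # 3 mantissa bits, max exponent 8  -> 448.0
    fp8e5m2_max = (2 - 2**-2) * 2**15   # 2 mantissa bits, max exponent 15 -> 57344.0
    fp8e4m3_min = np.exp2(-9)           # smallest positive value: 2**-6 * 2**-3
    fp8e5m2_min = np.exp2(-16)          # smallest positive value: 2**-14 * 2**-2

    assert fp8e4m3_max == 448 and fp8e5m2_max == 57344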
@@ -715,6 +719,8 @@ class TosaTensorValuesGen:
DType.FP16,
DType.FP32,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
):
# Change from inclusive to exclusive range
data_range = (data_range[0], data_range[1] + 1)
@@ -1734,7 +1740,13 @@ class TosaArgGen:
and "data_gen" in testGen.TOSA_OP_LIST[opName]
and gtu.dtypeIsSupportedByCompliance(dtype)
):
- if dtype in [DType.FP16, DType.FP32, DType.BF16]:
+ if dtype in [
+ DType.FP16,
+ DType.FP32,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]:
dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
else:
dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
@@ -2140,6 +2152,8 @@ class TosaArgGen:
accum_dtypes = [DType.FP32]
elif dtype == DType.FP32:
accum_dtypes = [DType.FP32]
+ elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
+ accum_dtypes = [DType.FP16]
elif error_name is None:
assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
@@ -2350,7 +2364,13 @@ class TosaArgGen:
if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
pad_const_int = testGen.getRandNumberDType(dtype)
pad_const_fp = 0
- elif dtype in (DType.FP16, DType.BF16, DType.FP32):
+ elif dtype in (
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ):
pad_const_int = 0
pad_const_fp = testGen.getRandNumberDType(dtype)
else:
@@ -2468,6 +2488,8 @@ class TosaArgGen:
accum_dtypes = [DType.FP16, DType.FP32]
elif dtype == DType.BF16 or dtype == DType.FP32:
accum_dtypes = [DType.FP32]
+ elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
+ accum_dtypes = [DType.FP16]
elif error_name is None:
assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
else:
@@ -2646,11 +2668,35 @@ class TosaArgGen:
elif inDtype == DType.BOOL:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP16:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ dtypeList = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
elif inDtype == DType.BF16:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ dtypeList = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
elif inDtype == DType.FP32:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
+ dtypeList = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
+ elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
+ dtypeList = [DType.FP16, DType.BF16, DType.FP32]
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output type for incorrect input type
dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
@@ -3232,6 +3278,10 @@ class TosaArgGen:
outputDTypeList = [DType.BF16]
elif dtype == DType.FP32:
outputDTypeList = [DType.FP32]
+ elif dtype == DType.FP8E4M3:
+ outputDTypeList = [DType.FP8E4M3]
+ elif dtype == DType.FP8E5M2:
+ outputDTypeList = [DType.FP8E5M2]
elif error_name == ErrorIf.WrongInputType:
# If an incorrect input type is used then we set a 'correct'
# output type to avoid other errors
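Note: taken together, the CAST changes above give the following conversion coverage once fp8 is included. This is only a restatement of the dtype lists in the hunks above, written as a plain mapping for readability:

    # Illustrative restatement of the CAST output-dtype lists added above.
    CAST_TARGETS_WITH_FP8 = {
        "FP16":    ["INT8", "INT16", "INT32", "FP32", "FP8E4M3", "FP8E5M2"],
        "BF16":    ["INT8", "INT16", "INT32", "FP32", "FP8E4M3", "FP8E5M2"],
        "FP32":    ["INT8", "INT16", "INT32", "FP16", "BF16", "FP8E4M3", "FP8E5M2"],
        "FP8E4M3": ["FP16", "BF16", "FP32"],
        "FP8E5M2": ["FP16", "BF16", "FP32"],
    }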
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 9a88acb..7a4d0d6 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -325,12 +325,32 @@ class TosaErrorIfArgGen:
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FP32]:
+ # if input_dtype in [DType.BOOL, DType.FP32]:
+ # outputDType = [DType.BOOL, DType.INT48, DType.FP32]
+ if input_dtype in [DType.BOOL]:
+ outputDType = [
+ DType.BOOL,
+ DType.INT48,
+ DType.FP32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
+ elif input_dtype in [DType.FP32]:
outputDType = [DType.BOOL, DType.INT48, DType.FP32]
elif input_dtype in [DType.FP16, DType.BF16]:
outputDType = [DType.BOOL, DType.INT48]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
+ elif input_dtype in [DType.FP8E4M3, DType.FP8E5M2]:
+ outputDType = [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ ]
else:
assert False, f"input_dtype ({input_dtype}) not supported"
return outputDType
@@ -476,13 +496,23 @@ class TosaErrorValidator:
)
or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
+ or (input_dtype == DType.FP8E4M3 and output_dtype != DType.FP16)
+ or (input_dtype == DType.FP8E5M2 and output_dtype != DType.FP16)
):
error_result = True
elif op["op"] == Op.ARGMAX:
if (
input_dtype
- in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
+ in [
+ DType.INT8,
+ DType.INT16,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
and output_dtype != DType.INT32
):
error_result = True
@@ -555,12 +585,26 @@ class TosaErrorValidator:
or (
input_dtype == DType.FP16
and output_dtype
- not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ not in [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
)
or (
input_dtype == DType.BF16
and output_dtype
- not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ not in [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
)
or (
input_dtype == DType.FP32
@@ -571,6 +615,17 @@ class TosaErrorValidator:
DType.INT32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
+ )
+ or (
+ input_dtype in [DType.FP8E4M3, DType.FP8E5M2]
+ and output_dtype
+ not in [
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
]
)
):
@@ -597,6 +652,10 @@ class TosaErrorValidator:
and output_dtype != DType.FP32
or input_dtype == DType.FP32
and output_dtype != DType.FP32
+ or input_dtype == DType.FP8E4M3
+ and output_dtype != DType.FP16
+ or input_dtype == DType.FP8E5M2
+ and output_dtype != DType.FP16
):
error_result = True
# invalid input types are ignored, to avoid reporting multiple errors
@@ -2615,6 +2674,11 @@ class TosaErrorValidator:
DType.FP32,
):
error_result = True
+ elif (
+ input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
+ and accum_dtype != DType.FP16
+ ):
+ error_result = True
info_dict = {
"error_name": error_name,
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 4ead982..bc931dc 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -76,7 +76,7 @@ class TosaTestGen:
return tuple(sorted(vals))
self.random_float_range = {}
- for dtype in (DType.FP32, DType.FP16, DType.BF16):
+ for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
self.random_float_range[dtype] = convertFPRange(
args.tensor_fp_value_range,
TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
@@ -152,7 +152,7 @@ class TosaTestGen:
# Returns dtype value range boundaries (low, high)
# The high boundary is excluded in the range
# unless high_inclusive is True
- if dtype in (DType.FP32, DType.FP16, DType.BF16):
+ if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
return self.random_float_range[dtype]
elif dtype == DType.BOOL:
rng = (0, 2)
@@ -197,7 +197,13 @@ class TosaTestGen:
return np.uint8(self.rng.integers(low=low, high=high, size=shape))
elif dtype in (DType.INT48, DType.SHAPE):
return np.int64(self.rng.integers(low=low, high=high, size=shape))
- elif dtype in (DType.FP16, DType.BF16, DType.FP32):
+ elif dtype in (
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ):
f_tensor = self.rng.uniform(low=low, high=high, size=shape)
if dtype == DType.FP16:
@@ -207,6 +213,10 @@ class TosaTestGen:
if dtype == DType.BF16:
# Floor the last 16 bits of each f32 value
return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
+ elif dtype == DType.FP8E4M3:
+ return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor))
+ elif dtype == DType.FP8E5M2:
+ return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor))
else:
return f32_tensor
else:
@@ -266,6 +276,12 @@ class TosaTestGen:
elif dtype == DType.BF16:
rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
return gtu.vect_f32_to_bf16(rand_f32)
+ elif dtype == DType.FP8E4M3:
+ rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+ return gtu.vect_f32_to_fp8e4m3(rand_f32)
+ elif dtype == DType.FP8E5M2:
+ rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+ return gtu.vect_f32_to_fp8e5m2(rand_f32)
elif dtype == DType.BOOL:
return self.rng.choice([False, True])
elif dtype == DType.INT48 or dtype == DType.SHAPE:
@@ -1408,8 +1424,11 @@ class TosaTestGen:
max_val = max_val.astype(np.float32)
attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
- else:
+ elif a.dtype in (DType.INT8, DType.INT16):
attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
+ else:
+ # to avoid internal error for incorrect input types
+ attr.ClampAttribute(self.ser.builder, 0, 0, 0, 0)
self.ser.addOperator(op["op"], input_list, output_list, attr)
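Note: the CLAMP change above splits attribute construction three ways: float inputs fill the float min/max fields, INT8/INT16 fill the integer fields, and any other (deliberately wrong) input type gets all-zero fields so ERROR_IF tests do not hit an internal serializer error. A condensed view, assuming the same ClampAttribute argument order used above:

    # Condensed, illustrative form of the CLAMP attribute dispatch above.
    def build_clamp_attr(attr, builder, dtype, min_val, max_val, float_dtypes, int_dtypes):
        if dtype in float_dtypes:
            attr.ClampAttribute(builder, 0, 0, min_val, max_val)   # float fields
        elif dtype in int_dtypes:
            attr.ClampAttribute(builder, min_val, max_val, 0, 0)   # integer fields
        else:
            # Wrong-type ERROR_IF case: keep the serializer happy with zeros.
            attr.ClampAttribute(builder, 0, 0, 0, 0)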
@@ -3190,7 +3209,13 @@ class TosaTestGen:
]
TYPE_FI16 = [DType.FP32, DType.INT16]
- TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
+ TYPE_NARROW_INT_FP = [
+ DType.INT8,
+ DType.INT16,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ]
# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
@@ -3201,6 +3226,8 @@ class TosaTestGen:
[DType.FP16, DType.FP16, DType.FP32],
[DType.BF16, DType.BF16, DType.FP32],
[DType.FP32, DType.FP32, DType.FP32],
+ [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
+ [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
@@ -3217,7 +3244,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evAxisLargerRank,
@@ -3244,7 +3271,7 @@ class TosaTestGen:
TosaArgGen.agPooling,
),
"qgen": TosaQuantGen.qgUnary,
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evKernelSmallerOne,
@@ -3402,7 +3429,7 @@ class TosaTestGen:
TosaArgGen.agMatMul,
),
"qgen": TosaQuantGen.qgMatmul,
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWrongRank,
@@ -3425,7 +3452,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agPooling,
),
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evKernelSmallerOne,
@@ -4389,7 +4416,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgConcat,
TosaArgGen.agAxis,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
@@ -4413,7 +4440,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgPad,
TosaArgGen.agPad,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evPadSmallerZero,
@@ -4437,7 +4464,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
@@ -4456,7 +4483,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgReshape,
TosaArgGen.agReshape,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evTensorSizeInputOutputMismatch,
TosaErrorValidator.evWrongInputType,
@@ -4477,7 +4504,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evAxisLargerRank,
@@ -4500,7 +4527,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgSlice,
TosaArgGen.agSlice,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
# TODO Turn off these error categories for now as the reference
# model cannot allocate memory space for empty tensor. We probably
@@ -4532,7 +4559,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgTile,
TosaArgGen.agTile,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -4555,7 +4582,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agTranspose,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evIndexOutsideBounds,
TosaErrorValidator.evIndexUsedTwice,
@@ -4581,7 +4608,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
- "types": TYPE_FIB + [DType.INT48],
+ "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
"data_gen": {
"fp": (gtu.DataGenType.PSEUDO_RANDOM,),
},
@@ -4618,6 +4645,8 @@ class TosaTestGen:
DType.FP16,
DType.BF16,
DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -4640,7 +4669,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgScatter,
TosaArgGen.agNone,
),
- "types": TYPE_INT_FP,
+ "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -4709,6 +4738,8 @@ class TosaTestGen:
DType.INT16,
DType.INT32,
DType.BOOL,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -5141,6 +5172,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
]
wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
outputDType = rng.choice(wrong_dtypes)
@@ -5194,6 +5227,8 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
excludes = [DType.FP16, DType.FP32]
+ if ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
+ excludes = [DType.FP16]
else:
excludes = [out_dtype]
wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
@@ -5344,6 +5379,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
]
wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -5383,6 +5420,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
)
elif a.dtype == DType.INT16:
incorrect_types = (
@@ -5393,6 +5432,20 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ )
+ elif a.dtype == DType.FP8E4M3 or a.dtype == DType.FP8E5M2:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FP32,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
)
elif (
a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
@@ -5403,6 +5456,8 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
)
out_dtype = rng.choice(a=incorrect_types)
elif error_name == ErrorIf.WrongInputType:
@@ -5669,6 +5724,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
]
wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
outputDType = rng.choice(wrong_dtypes)
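Note: on the generator side, fp8 tensors are produced the same way as bf16 ones: draw uniform float32 data within the fp8 range, then truncate each element with the new vectorised helpers. A minimal sketch, assuming the helpers added in tosa_utils.py (next file):

    # Minimal sketch of fp8 random-tensor generation as wired up above.
    import numpy as np
    from generator import tosa_utils as gtu

    def rand_fp8_tensor(rng, shape, low, high, e4m3=True):
        f32 = np.float32(rng.uniform(low=low, high=high, size=shape))
        trunc = gtu.vect_f32_to_fp8e4m3 if e4m3 else gtu.vect_f32_to_fp8e5m2
        return np.float32(trunc(f32))    # values stored as float32, fp8-truncated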
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 76e7388..31a0ff0 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -27,6 +27,8 @@ DTYPE_ATTRIBUTES = {
DType.FP16: {"str": "f16", "width": 16, "json": "FP16"},
DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"},
DType.FP32: {"str": "f32", "width": 32, "json": "FP32"},
+ DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "json": "FP8E4M3"},
+ DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "json": "FP8E5M2"},
}
@@ -186,6 +188,16 @@ def get_wrong_output_type(op_name, rng, input_dtype):
DType.INT32,
DType.INT48,
)
+ elif input_dtype == DType.FP8E4M3 or input_dtype == DType.FP8E5M2:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FP32,
+ DType.BF16,
+ )
else:
# Assume all types but the input type are incorrect
incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
@@ -209,6 +221,12 @@ def float32_is_valid_bfloat16(f):
return f32_bits[16:] == "0" * 16
+def float32_is_valid_float8(f):
+ """Return True if float value is valid float8."""
+ f32_bits = get_float32_bitstring(f)
+ return f32_bits[8:] == "0" * 24
+
+
def get_float32_bitstring(f):
"""Return a big-endian string of bits representing a 32 bit float."""
f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0]
@@ -232,6 +250,30 @@ def float32_to_bfloat16(f):
return struct.unpack("@f", fp_bytes)[0] # native byteorder
+def float32_to_fp8e4m3(f):
+ """Turns fp32 value into fp8e4m3"""
+ f32_bits = get_float32_bitstring(f)
+ fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] + "0" * 24
+ fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder)
+ return struct.unpack("@f", fp_bytes)[0] # native byteorder
+
+
+def float32_to_fp8e5m2(f):
+ """Turns fp32 value into fp8e5m2"""
+ f32_bits = get_float32_bitstring(f)
+ fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] + "0" * 24
+ fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder)
+ return struct.unpack("@f", fp_bytes)[0]
+
+
vect_f32_to_bf16 = np.vectorize(
float32_to_bfloat16, otypes=(np.float32,)
) # NumPy vectorize: applies function to vector faster than looping
+
+vect_f32_to_fp8e4m3 = np.vectorize(
+ float32_to_fp8e4m3, otypes=(np.float32,)
+) # NumPy vectorize: applies function to vector faster than looping
+
+vect_f32_to_fp8e5m2 = np.vectorize(
+ float32_to_fp8e5m2, otypes=(np.float32,)
+) # Numpy vectorize: applies function to vector faster than looping
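Note: like float32_is_valid_bfloat16, the new float8 check only inspects the float32 bit pattern: a value passes if everything below the top 8 bits of its encoding is zero. A standalone illustration of that bit test, using the same big-endian packing as get_float32_bitstring:

    # Standalone illustration of the float8 validity test added above.
    import struct

    def low_24_bits_are_zero(f):
        bits = struct.unpack(">L", struct.pack(">f", f))[0]   # big-endian f32 bits
        return (bits & 0x00FFFFFF) == 0

    print(low_24_bits_are_zero(0.0))   # True: all-zero payload
    print(low_24_bits_are_zero(1.5))   # False: mantissa and low exponent bits are set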