diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2023-11-15 16:25:45 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-11-30 18:52:24 +0000 |
commit | 708da823504b9a7f4e2ffc10e00f06bb092ce637 (patch) | |
tree | accbf5aaf055cb07d60fec3c14b7001a8c0fc710 /verif/generator/tosa_arg_gen.py | |
parent | 3047625f7d4b3a77cb3a3480481122f7ba01be2d (diff) | |
download | reference_model-708da823504b9a7f4e2ffc10e00f06bb092ce637.tar.gz |
Main Compliance testing support for CAST
Limit CAST input tensor to maximums of output type to avoid
saturation and infinity.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I33350a4ce0ec828da7d2e7aa8cd3183a89a97431
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 34 |
1 files changed, 33 insertions, 1 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 3057963..c557207 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1434,6 +1434,27 @@ class TosaTensorValuesGen: testGen, opName, dtypeList, shapeList, argsDict, error_name ) + @staticmethod + def tvgCast(testGen, opName, dtypeList, shapeList, argsDict, error_name=None): + in_dtype = dtypeList[0] + out_dtype = argsDict["out_type"] + # Create look up to limit input tensor to output type maximums to avoid + # FP infinities and saturation of integers + out_range = testGen.getDTypeRange(out_dtype, high_inclusive=True) + highval_lookup = {in_dtype: out_range[1]} + data_range = TosaTensorValuesGen._get_data_range( + testGen, + in_dtype, + highval_lookup, + ) + + assert data_range is not None + argsDict["data_range"] = data_range + + return TosaTensorValuesGen.tvgLazyGenDefault( + testGen, opName, dtypeList, shapeList, argsDict, error_name + ) + class TosaArgGen: """Argument generators create exhaustive or random lists of attributes for @@ -2350,7 +2371,18 @@ class TosaArgGen: raise Exception("Unexpected input dtype: {}".format(inDtype)) for dtype in dtypeList: - arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype])) + arg_list.append( + ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype}) + ) + + # Now add data generator types + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ) return arg_list |