aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-11-15 16:25:45 +0000
committerEric Kunze <eric.kunze@arm.com>2023-11-30 18:52:24 +0000
commit708da823504b9a7f4e2ffc10e00f06bb092ce637 (patch)
treeaccbf5aaf055cb07d60fec3c14b7001a8c0fc710 /verif/generator/tosa_arg_gen.py
parent3047625f7d4b3a77cb3a3480481122f7ba01be2d (diff)
downloadreference_model-708da823504b9a7f4e2ffc10e00f06bb092ce637.tar.gz
Main Compliance testing support for CAST
Limit CAST input tensor to maximums of output type to avoid saturation and infinity. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I33350a4ce0ec828da7d2e7aa8cd3183a89a97431
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py34
1 files changed, 33 insertions, 1 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 3057963..c557207 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1434,6 +1434,27 @@ class TosaTensorValuesGen:
testGen, opName, dtypeList, shapeList, argsDict, error_name
)
+ @staticmethod
+ def tvgCast(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
+ in_dtype = dtypeList[0]
+ out_dtype = argsDict["out_type"]
+ # Create look up to limit input tensor to output type maximums to avoid
+ # FP infinities and saturation of integers
+ out_range = testGen.getDTypeRange(out_dtype, high_inclusive=True)
+ highval_lookup = {in_dtype: out_range[1]}
+ data_range = TosaTensorValuesGen._get_data_range(
+ testGen,
+ in_dtype,
+ highval_lookup,
+ )
+
+ assert data_range is not None
+ argsDict["data_range"] = data_range
+
+ return TosaTensorValuesGen.tvgLazyGenDefault(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
+ )
+
class TosaArgGen:
"""Argument generators create exhaustive or random lists of attributes for
@@ -2350,7 +2371,18 @@ class TosaArgGen:
raise Exception("Unexpected input dtype: {}".format(inDtype))
for dtype in dtypeList:
- arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype]))
+ arg_list.append(
+ ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
+ )
+
+ # Now add data generator types
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ )
return arg_list