aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorevacha01 <evan.chandler@arm.com>2024-03-19 12:42:17 +0000
committerDominic Symes <dominic.symes@arm.com>2024-04-03 10:12:22 +0000
commitad8e1e25e805f6face5fcf0b3906cd06db46e1d7 (patch)
tree4b59c93f1ff6844b9bbcc744f9c15212c05f2dcf
parent12159fc6fb776908f48fbda9c74cf34980540e4f (diff)
downloadreference_model-ad8e1e25e805f6face5fcf0b3906cd06db46e1d7.tar.gz
Make Full Range FP16 tests into extra tests
Signed-off-by: evacha01 <evan.chandler@arm.com> Change-Id: I8c59ecb5a1fb53d0e9bf64333709f9e3cc908b49
-rw-r--r--reference_model/src/generate/generate_full_range.cc4
-rwxr-xr-xscripts/convert2conformance/convert2conformance.py4
-rw-r--r--verif/conformance/test_select.py21
-rw-r--r--verif/generator/tosa_arg_gen.py42
-rw-r--r--verif/generator/tosa_test_gen.py229
5 files changed, 115 insertions, 185 deletions
diff --git a/reference_model/src/generate/generate_full_range.cc b/reference_model/src/generate/generate_full_range.cc
index d2a89da..6f1deb2 100644
--- a/reference_model/src/generate/generate_full_range.cc
+++ b/reference_model/src/generate/generate_full_range.cc
@@ -41,7 +41,7 @@ bool generateFullRange(const GenerateConfig& cfg, void* data, size_t size)
// Check we support the operator
if (cfg.opType == Op::Op_UNKNOWN)
{
- WARNING("[Generator][PR] Unknown operator.");
+ WARNING("[Generator][FR] Unknown operator.");
return false;
}
@@ -52,7 +52,7 @@ bool generateFullRange(const GenerateConfig& cfg, void* data, size_t size)
return generate(cfg, outData, size);
}
default:
- WARNING("[Generator][PR] Unsupported type.");
+ WARNING("[Generator][FR] Unsupported type.");
return false;
}
}
diff --git a/scripts/convert2conformance/convert2conformance.py b/scripts/convert2conformance/convert2conformance.py
index 531dca8..4a006d6 100755
--- a/scripts/convert2conformance/convert2conformance.py
+++ b/scripts/convert2conformance/convert2conformance.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# Copyright (c) 2021-2023, ARM Limited.
+# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
"""This script converts generated tests into conformance tests.
@@ -240,7 +240,7 @@ def update_desc_json(
# Add tags (if any)
if tags is not None:
- test_desc["tag"] = tags
+ test_desc["tag"] = test_desc.get("tag", []) + tags
return test_desc
diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py
index 58d3f9f..e3a8ffb 100644
--- a/verif/conformance/test_select.py
+++ b/verif/conformance/test_select.py
@@ -259,11 +259,11 @@ class Operator:
negative and "ERRORIF" in str(path)
):
# Check for test set paths
- match = re.match(r"(.*)_s([0-9]+)", path.name)
+ match = re.match(r"(.*)_(s[0-9]+|full)", path.name)
if match:
- if match.group(2) == "0":
+ if match.group(2) in ["s0", "full"]:
# Only return the truncated test name
- # of the first test of a set
+ # of the first test of a set, and for full tests
yield path.with_name(match.group(1))
else:
yield path
@@ -308,11 +308,21 @@ class Operator:
paths.append(set_path)
else:
if s == 0:
- logger.error(f"Could not find test set 0 - {str(set_path)}")
+ logger.warning(f"Could not find test set 0 - {str(set_path)}")
break
s += 1
return paths
+ @staticmethod
+ def _get_extra_test_paths(path):
+ """Expand a path to find extra tests."""
+ paths = []
+ for suffix in ["full"]:
+ suffix_path = path.with_name(f"{path.name}_{suffix}")
+ if suffix_path.exists():
+ paths.append(suffix_path)
+ return paths
+
def select_tests(self): # noqa: C901 (function too complex)
"""Generate the paths to the selected tests for this operator."""
if not self.test_paths:
@@ -374,6 +384,9 @@ class Operator:
# Must be a test set - expand to all test sets
for p in Operator._get_test_set_paths(path):
yield p
+ # check for extra tests
+ for p in Operator._get_extra_test_paths(path):
+ yield p
# search for tests that match any unused parameter values
for n, path in enumerate(sorted(list(unused_paths))):
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 79d4e78..f9499b5 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1831,10 +1831,10 @@ class TosaArgGen:
and "data_gen" in testGen.TOSA_OP_LIST[opName]
and gtu.dtypeIsSupportedByCompliance(dtype)
):
- if gtu.dtypeIsFloat(dtype):
- dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
- else:
- dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
+ dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"].get(
+ dtype, (gtu.DataGenType.PSEUDO_RANDOM,)
+ )
+
else:
# Error test or No data generator types listed - assume random
dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
@@ -1843,16 +1843,7 @@ class TosaArgGen:
new_arg_list = []
for dg_type in dataGenTypesList:
for arg_str, args_dict in arg_list:
-
- if dg_type == gtu.DataGenType.FULL_RANGE:
- tensor_size = gtu.product(shapeList[0])
- if tensor_size >= gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
- # Large enough tensor data size for full range, add a single test
- num_test_sets = 0
- else:
- # Not enough data size for full range of values, revert to random numbers
- dg_type = gtu.DataGenType.PSEUDO_RANDOM
-
+ gen_args_dict = args_dict.copy()
if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
if error_name is None:
num_test_sets = (
@@ -1883,18 +1874,31 @@ class TosaArgGen:
num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
+ elif dg_type == gtu.DataGenType.FULL_RANGE:
+ tensor_size = gtu.product(shapeList[0])
+ if tensor_size < gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
+ shape_info = " ({})".format(shapeList[0])
+ logger.info(
+ f"Skipping {opName}{shape_info} as tensor data size too small for full range of values {tensor_size} < {gtu.DTYPE_ATTRIBUTES[dtype]['fullset']}"
+ )
+ continue
+ # Large enough tensor data size for full range, add a single test
+ num_test_sets = 0
+ arg_str = f"{arg_str}_full" if arg_str else "full"
+ gen_args_dict["tags"] = args_dict.get("tags", []) + [
+ "non_finite_fp_data"
+ ]
+
+ gen_args_dict["dg_type"] = dg_type
if num_test_sets > 0:
for s in range(0, num_test_sets):
set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
- set_args_dict = args_dict.copy()
+ set_args_dict = gen_args_dict.copy()
set_args_dict["s"] = s
- set_args_dict["dg_type"] = dg_type
new_arg_list.append((set_arg_str, set_args_dict))
else:
# Default is a single test
- new_args_dict = args_dict.copy()
- new_args_dict["dg_type"] = dg_type
- new_arg_list.append((arg_str, new_args_dict))
+ new_arg_list.append((arg_str, gen_args_dict))
return new_arg_list
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 71d7fcc..399fed6 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -110,7 +110,7 @@ class TosaTestGen:
def getSerializer(self):
return self.ser
- def serialize(self, testName, metaData=None):
+ def serialize(self, testName, metaData=None, tags=None):
path = Path(self.basePath) / self.testPath
# Write out TOSA flatbuffer binary
@@ -125,6 +125,9 @@ class TosaTestGen:
# Add extra meta data to desc.json
desc["meta"] = metaData
+ if tags:
+ desc["tag"] = tags
+
# Validate desc.json before we output it
self.descSchemaValidator.validate_config(desc)
@@ -3146,6 +3149,8 @@ class TosaTestGen:
tensMeta["data_gen"] = tvgInfo.dataGenDict
tens = tvgInfo.tensorList
+ tags = argsDict.get("tags", None)
+
result = build_fcn(
self,
build_rng,
@@ -3164,7 +3169,7 @@ class TosaTestGen:
compliance = result.getComplianceInfo()
if compliance:
tensMeta["compliance"] = compliance
- self.serialize("test", tensMeta)
+ self.serialize("test", tensMeta, tags)
return True
else:
# The test is not valid
@@ -3326,6 +3331,18 @@ class TosaTestGen:
KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
+ PSEUDO_RANDOM_DATAGEN = {
+ DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM,),
+ DType.FP32: (gtu.DataGenType.PSEUDO_RANDOM,),
+ }
+ DOT_PRODUCT_DATAGEN = {
+ DType.FP16: (gtu.DataGenType.DOT_PRODUCT,),
+ DType.FP32: (gtu.DataGenType.DOT_PRODUCT,),
+ }
+ EW_UNARY_DATAGEN = {
+ DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FULL_RANGE)
+ }
+
TOSA_OP_LIST = {
# Tensor operators
"argmax": {
@@ -3350,9 +3367,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"avg_pool2d": {
"op": Op.AVG_POOL2D,
@@ -3383,9 +3398,7 @@ class TosaTestGen:
TosaErrorValidator.evPoolingOutputShapeNonInteger,
TosaErrorValidator.evWrongAccumulatorType,
),
- "data_gen": {
- "fp": (gtu.DataGenType.DOT_PRODUCT,),
- },
+ "data_gen": DOT_PRODUCT_DATAGEN,
},
# Templated operator. Filled in by createDynamicOpLists
"conv2d_TEMPLATE": {
@@ -3416,9 +3429,7 @@ class TosaTestGen:
TosaErrorValidator.evConvOutputShapeNonInteger,
TosaErrorValidator.evWrongAccumulatorType,
),
- "data_gen": {
- "fp": (gtu.DataGenType.DOT_PRODUCT,),
- },
+ "data_gen": DOT_PRODUCT_DATAGEN,
"broadcastable_bias": True,
"filter": KERNELS_2D,
"template": True,
@@ -3452,9 +3463,7 @@ class TosaTestGen:
TosaErrorValidator.evConvOutputShapeNonInteger,
TosaErrorValidator.evWrongAccumulatorType,
),
- "data_gen": {
- "fp": (gtu.DataGenType.DOT_PRODUCT,),
- },
+ "data_gen": DOT_PRODUCT_DATAGEN,
"filter": KERNELS_3D,
"template": True,
},
@@ -3487,9 +3496,7 @@ class TosaTestGen:
TosaErrorValidator.evConvOutputShapeNonInteger,
TosaErrorValidator.evWrongAccumulatorType,
),
- "data_gen": {
- "fp": (gtu.DataGenType.DOT_PRODUCT,),
- },
+ "data_gen": DOT_PRODUCT_DATAGEN,
"filter": KERNELS_2D,
"template": True,
},
@@ -3514,9 +3521,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.DOT_PRODUCT,),
- },
+ "data_gen": DOT_PRODUCT_DATAGEN,
},
"matmul": {
"op": Op.MATMUL,
@@ -3538,9 +3543,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.DOT_PRODUCT,),
- },
+ "data_gen": DOT_PRODUCT_DATAGEN,
},
"max_pool2d": {
"op": Op.MAX_POOL2D,
@@ -3567,9 +3570,7 @@ class TosaTestGen:
TosaErrorValidator.evPoolingOutputShapeMismatch,
TosaErrorValidator.evPoolingOutputShapeNonInteger,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
# Templated operator. Filled in by createDynamicOpLists
"transpose_conv2d_TEMPLATE": {
@@ -3601,9 +3602,7 @@ class TosaTestGen:
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evWrongAccumulatorType,
),
- "data_gen": {
- "fp": (gtu.DataGenType.DOT_PRODUCT,),
- },
+ "data_gen": DOT_PRODUCT_DATAGEN,
"filter": KERNELS_2D,
"template": True,
},
@@ -3625,9 +3624,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"sigmoid": {
"op": Op.SIGMOID,
@@ -3645,9 +3642,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"tanh": {
"op": Op.TANH,
@@ -3665,9 +3660,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {
"abs_error_lower_bound": 0.5,
},
@@ -3688,9 +3681,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {"ulp": 5},
},
# Elementwise Binary Operators
@@ -3713,9 +3704,7 @@ class TosaTestGen:
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {"ulp": 0.5},
},
"arithmetic_right_shift": {
@@ -3937,9 +3926,7 @@ class TosaTestGen:
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"minimum": {
"op": Op.MINIMUM,
@@ -3960,9 +3947,7 @@ class TosaTestGen:
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"mul": {
"op": Op.MUL,
@@ -3983,9 +3968,7 @@ class TosaTestGen:
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {"ulp": 0.5},
},
"pow": {
@@ -4007,9 +3990,7 @@ class TosaTestGen:
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"sub": {
"op": Op.SUB,
@@ -4030,9 +4011,7 @@ class TosaTestGen:
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {"ulp": 0.5},
},
"table": {
@@ -4072,9 +4051,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.FULL_RANGE,),
- },
+ "data_gen": EW_UNARY_DATAGEN,
},
"bitwise_not": {
"op": Op.BITWISE_NOT,
@@ -4109,9 +4086,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.FULL_RANGE,),
- },
+ "data_gen": EW_UNARY_DATAGEN,
"compliance": {"ulp": 0.5},
},
"clz": {
@@ -4147,9 +4122,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {"abs_error_normal_divisor": 2},
},
"exp": {
@@ -4168,9 +4141,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.FULL_RANGE,),
- },
+ "data_gen": EW_UNARY_DATAGEN,
},
"floor": {
"op": Op.FLOOR,
@@ -4188,9 +4159,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.FULL_RANGE,),
- },
+ "data_gen": EW_UNARY_DATAGEN,
"compliance": {"ulp": 0.5},
},
"log": {
@@ -4209,9 +4178,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.FULL_RANGE,),
- },
+ "data_gen": EW_UNARY_DATAGEN,
"compliance": {"ulp": 5},
},
"logical_not": {
@@ -4250,9 +4217,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.FULL_RANGE,),
- },
+ "data_gen": EW_UNARY_DATAGEN,
},
"reciprocal": {
"op": Op.RECIPROCAL,
@@ -4270,9 +4235,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.FULL_RANGE,),
- },
+ "data_gen": EW_UNARY_DATAGEN,
"compliance": {"ulp": 1.0},
},
"rsqrt": {
@@ -4291,9 +4254,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.FULL_RANGE,),
- },
+ "data_gen": EW_UNARY_DATAGEN,
"compliance": {"ulp": 2},
},
"sin": {
@@ -4312,9 +4273,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {"abs_error_normal_divisor": 2},
},
# Elementwise Ternary operators
@@ -4337,9 +4296,7 @@ class TosaTestGen:
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
# Comparison operators
"equal": {
@@ -4361,9 +4318,7 @@ class TosaTestGen:
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"greater_equal": {
"op": Op.GREATER_EQUAL,
@@ -4384,9 +4339,7 @@ class TosaTestGen:
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"greater": {
"op": Op.GREATER,
@@ -4407,9 +4360,7 @@ class TosaTestGen:
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
# Reduction operators
"reduce_all": {
@@ -4474,9 +4425,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"reduce_min": {
"op": Op.REDUCE_MIN,
@@ -4498,9 +4447,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"reduce_product": {
"op": Op.REDUCE_PRODUCT,
@@ -4522,9 +4469,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"reduce_sum": {
"op": Op.REDUCE_SUM,
@@ -4546,9 +4491,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.DOT_PRODUCT,),
- },
+ "data_gen": DOT_PRODUCT_DATAGEN,
},
# Data layout operators
"concat": {
@@ -4571,9 +4514,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"pad": {
"op": Op.PAD,
@@ -4595,9 +4536,7 @@ class TosaTestGen:
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongRank,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"dim": {
"op": Op.DIM,
@@ -4635,9 +4574,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"reverse": {
"op": Op.REVERSE,
@@ -4657,9 +4594,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"slice": {
"op": Op.SLICE,
@@ -4689,9 +4624,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evRankMismatch,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"tile": {
"op": Op.TILE,
@@ -4712,9 +4645,7 @@ class TosaTestGen:
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongRank,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"transpose": {
"op": Op.TRANSPOSE,
@@ -4738,9 +4669,7 @@ class TosaTestGen:
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evTensorSizeInputOutputMismatch,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
# Data nodes
"const": {
@@ -4753,9 +4682,7 @@ class TosaTestGen:
TosaArgGen.agNone,
),
"types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"identity": {
"op": Op.IDENTITY,
@@ -4767,9 +4694,7 @@ class TosaTestGen:
TosaArgGen.agNone,
),
"types": TYPE_FIB + [DType.INT4, DType.INT48],
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
# Scatter/Gather
"gather": {
@@ -4799,9 +4724,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evWrongRank,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
"scatter": {
"op": Op.SCATTER,
@@ -4821,9 +4744,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evWrongRank,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
},
# Image operations
"resize": {
@@ -4859,9 +4780,7 @@ class TosaTestGen:
TosaErrorValidator.evResizeOutputShapeMismatch,
TosaErrorValidator.evResizeOutputShapeNonInteger,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {"relative": 0.006},
},
# Type conversion
@@ -4891,9 +4810,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
- },
+ "data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {"ulp": 0.5},
},
"rescale": {
@@ -5010,9 +4927,7 @@ class TosaTestGen:
TosaErrorValidator.evFFTInputShapeMismatch,
TosaErrorValidator.evFFTOutputShapeMismatch,
),
- "data_gen": {
- "fp": (gtu.DataGenType.DOT_PRODUCT,),
- },
+ "data_gen": DOT_PRODUCT_DATAGEN,
},
"rfft2d": {
"op": Op.RFFT2D,
@@ -5035,9 +4950,7 @@ class TosaTestGen:
TosaErrorValidator.evKernelNotPowerOfTwo,
TosaErrorValidator.evFFTOutputShapeMismatch,
),
- "data_gen": {
- "fp": (gtu.DataGenType.DOT_PRODUCT,),
- },
+ "data_gen": DOT_PRODUCT_DATAGEN,
},
# Shape
"add_shape": {