From ad8e1e25e805f6face5fcf0b3906cd06db46e1d7 Mon Sep 17 00:00:00 2001 From: evacha01 Date: Tue, 19 Mar 2024 12:42:17 +0000 Subject: Make Full Range FP16 tests into extra tests Signed-off-by: evacha01 Change-Id: I8c59ecb5a1fb53d0e9bf64333709f9e3cc908b49 --- .../src/generate/generate_full_range.cc | 4 +- scripts/convert2conformance/convert2conformance.py | 4 +- verif/conformance/test_select.py | 21 +- verif/generator/tosa_arg_gen.py | 42 ++-- verif/generator/tosa_test_gen.py | 229 +++++++-------------- 5 files changed, 115 insertions(+), 185 deletions(-) diff --git a/reference_model/src/generate/generate_full_range.cc b/reference_model/src/generate/generate_full_range.cc index d2a89da..6f1deb2 100644 --- a/reference_model/src/generate/generate_full_range.cc +++ b/reference_model/src/generate/generate_full_range.cc @@ -41,7 +41,7 @@ bool generateFullRange(const GenerateConfig& cfg, void* data, size_t size) // Check we support the operator if (cfg.opType == Op::Op_UNKNOWN) { - WARNING("[Generator][PR] Unknown operator."); + WARNING("[Generator][FR] Unknown operator."); return false; } @@ -52,7 +52,7 @@ bool generateFullRange(const GenerateConfig& cfg, void* data, size_t size) return generate(cfg, outData, size); } default: - WARNING("[Generator][PR] Unsupported type."); + WARNING("[Generator][FR] Unsupported type."); return false; } } diff --git a/scripts/convert2conformance/convert2conformance.py b/scripts/convert2conformance/convert2conformance.py index 531dca8..4a006d6 100755 --- a/scripts/convert2conformance/convert2conformance.py +++ b/scripts/convert2conformance/convert2conformance.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright (c) 2021-2023, ARM Limited. +# Copyright (c) 2021-2024, ARM Limited. # SPDX-License-Identifier: Apache-2.0 """This script converts generated tests into conformance tests. @@ -240,7 +240,7 @@ def update_desc_json( # Add tags (if any) if tags is not None: - test_desc["tag"] = tags + test_desc["tag"] = test_desc.get("tag", []) + tags return test_desc diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py index 58d3f9f..e3a8ffb 100644 --- a/verif/conformance/test_select.py +++ b/verif/conformance/test_select.py @@ -259,11 +259,11 @@ class Operator: negative and "ERRORIF" in str(path) ): # Check for test set paths - match = re.match(r"(.*)_s([0-9]+)", path.name) + match = re.match(r"(.*)_(s[0-9]+|full)", path.name) if match: - if match.group(2) == "0": + if match.group(2) in ["s0", "full"]: # Only return the truncated test name - # of the first test of a set + # of the first test of a set, and for full tests yield path.with_name(match.group(1)) else: yield path @@ -308,11 +308,21 @@ class Operator: paths.append(set_path) else: if s == 0: - logger.error(f"Could not find test set 0 - {str(set_path)}") + logger.warning(f"Could not find test set 0 - {str(set_path)}") break s += 1 return paths + @staticmethod + def _get_extra_test_paths(path): + """Expand a path to find extra tests.""" + paths = [] + for suffix in ["full"]: + suffix_path = path.with_name(f"{path.name}_{suffix}") + if suffix_path.exists(): + paths.append(suffix_path) + return paths + def select_tests(self): # noqa: C901 (function too complex) """Generate the paths to the selected tests for this operator.""" if not self.test_paths: @@ -374,6 +384,9 @@ class Operator: # Must be a test set - expand to all test sets for p in Operator._get_test_set_paths(path): yield p + # check for extra tests + for p in Operator._get_extra_test_paths(path): + yield p # search for tests that match any unused parameter values for n, path in enumerate(sorted(list(unused_paths))): diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 79d4e78..f9499b5 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1831,10 +1831,10 @@ class TosaArgGen: and "data_gen" in testGen.TOSA_OP_LIST[opName] and gtu.dtypeIsSupportedByCompliance(dtype) ): - if gtu.dtypeIsFloat(dtype): - dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"] - else: - dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"] + dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"].get( + dtype, (gtu.DataGenType.PSEUDO_RANDOM,) + ) + else: # Error test or No data generator types listed - assume random dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,) @@ -1843,16 +1843,7 @@ class TosaArgGen: new_arg_list = [] for dg_type in dataGenTypesList: for arg_str, args_dict in arg_list: - - if dg_type == gtu.DataGenType.FULL_RANGE: - tensor_size = gtu.product(shapeList[0]) - if tensor_size >= gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]: - # Large enough tensor data size for full range, add a single test - num_test_sets = 0 - else: - # Not enough data size for full range of values, revert to random numbers - dg_type = gtu.DataGenType.PSEUDO_RANDOM - + gen_args_dict = args_dict.copy() if dg_type == gtu.DataGenType.PSEUDO_RANDOM: if error_name is None: num_test_sets = ( @@ -1883,18 +1874,31 @@ class TosaArgGen: num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS + elif dg_type == gtu.DataGenType.FULL_RANGE: + tensor_size = gtu.product(shapeList[0]) + if tensor_size < gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]: + shape_info = " ({})".format(shapeList[0]) + logger.info( + f"Skipping {opName}{shape_info} as tensor data size too small for full range of values {tensor_size} < {gtu.DTYPE_ATTRIBUTES[dtype]['fullset']}" + ) + continue + # Large enough tensor data size for full range, add a single test + num_test_sets = 0 + arg_str = f"{arg_str}_full" if arg_str else "full" + gen_args_dict["tags"] = args_dict.get("tags", []) + [ + "non_finite_fp_data" + ] + + gen_args_dict["dg_type"] = dg_type if num_test_sets > 0: for s in range(0, num_test_sets): set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}" - set_args_dict = args_dict.copy() + set_args_dict = gen_args_dict.copy() set_args_dict["s"] = s - set_args_dict["dg_type"] = dg_type new_arg_list.append((set_arg_str, set_args_dict)) else: # Default is a single test - new_args_dict = args_dict.copy() - new_args_dict["dg_type"] = dg_type - new_arg_list.append((arg_str, new_args_dict)) + new_arg_list.append((arg_str, gen_args_dict)) return new_arg_list diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 71d7fcc..399fed6 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -110,7 +110,7 @@ class TosaTestGen: def getSerializer(self): return self.ser - def serialize(self, testName, metaData=None): + def serialize(self, testName, metaData=None, tags=None): path = Path(self.basePath) / self.testPath # Write out TOSA flatbuffer binary @@ -125,6 +125,9 @@ class TosaTestGen: # Add extra meta data to desc.json desc["meta"] = metaData + if tags: + desc["tag"] = tags + # Validate desc.json before we output it self.descSchemaValidator.validate_config(desc) @@ -3146,6 +3149,8 @@ class TosaTestGen: tensMeta["data_gen"] = tvgInfo.dataGenDict tens = tvgInfo.tensorList + tags = argsDict.get("tags", None) + result = build_fcn( self, build_rng, @@ -3164,7 +3169,7 @@ class TosaTestGen: compliance = result.getComplianceInfo() if compliance: tensMeta["compliance"] = compliance - self.serialize("test", tensMeta) + self.serialize("test", tensMeta, tags) return True else: # The test is not valid @@ -3326,6 +3331,18 @@ class TosaTestGen: KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]] KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]] + PSEUDO_RANDOM_DATAGEN = { + DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM,), + DType.FP32: (gtu.DataGenType.PSEUDO_RANDOM,), + } + DOT_PRODUCT_DATAGEN = { + DType.FP16: (gtu.DataGenType.DOT_PRODUCT,), + DType.FP32: (gtu.DataGenType.DOT_PRODUCT,), + } + EW_UNARY_DATAGEN = { + DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FULL_RANGE) + } + TOSA_OP_LIST = { # Tensor operators "argmax": { @@ -3350,9 +3367,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "avg_pool2d": { "op": Op.AVG_POOL2D, @@ -3383,9 +3398,7 @@ class TosaTestGen: TosaErrorValidator.evPoolingOutputShapeNonInteger, TosaErrorValidator.evWrongAccumulatorType, ), - "data_gen": { - "fp": (gtu.DataGenType.DOT_PRODUCT,), - }, + "data_gen": DOT_PRODUCT_DATAGEN, }, # Templated operator. Filled in by createDynamicOpLists "conv2d_TEMPLATE": { @@ -3416,9 +3429,7 @@ class TosaTestGen: TosaErrorValidator.evConvOutputShapeNonInteger, TosaErrorValidator.evWrongAccumulatorType, ), - "data_gen": { - "fp": (gtu.DataGenType.DOT_PRODUCT,), - }, + "data_gen": DOT_PRODUCT_DATAGEN, "broadcastable_bias": True, "filter": KERNELS_2D, "template": True, @@ -3452,9 +3463,7 @@ class TosaTestGen: TosaErrorValidator.evConvOutputShapeNonInteger, TosaErrorValidator.evWrongAccumulatorType, ), - "data_gen": { - "fp": (gtu.DataGenType.DOT_PRODUCT,), - }, + "data_gen": DOT_PRODUCT_DATAGEN, "filter": KERNELS_3D, "template": True, }, @@ -3487,9 +3496,7 @@ class TosaTestGen: TosaErrorValidator.evConvOutputShapeNonInteger, TosaErrorValidator.evWrongAccumulatorType, ), - "data_gen": { - "fp": (gtu.DataGenType.DOT_PRODUCT,), - }, + "data_gen": DOT_PRODUCT_DATAGEN, "filter": KERNELS_2D, "template": True, }, @@ -3514,9 +3521,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.DOT_PRODUCT,), - }, + "data_gen": DOT_PRODUCT_DATAGEN, }, "matmul": { "op": Op.MATMUL, @@ -3538,9 +3543,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.DOT_PRODUCT,), - }, + "data_gen": DOT_PRODUCT_DATAGEN, }, "max_pool2d": { "op": Op.MAX_POOL2D, @@ -3567,9 +3570,7 @@ class TosaTestGen: TosaErrorValidator.evPoolingOutputShapeMismatch, TosaErrorValidator.evPoolingOutputShapeNonInteger, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, # Templated operator. Filled in by createDynamicOpLists "transpose_conv2d_TEMPLATE": { @@ -3601,9 +3602,7 @@ class TosaTestGen: TosaErrorValidator.evConvOutputShapeMismatch, TosaErrorValidator.evWrongAccumulatorType, ), - "data_gen": { - "fp": (gtu.DataGenType.DOT_PRODUCT,), - }, + "data_gen": DOT_PRODUCT_DATAGEN, "filter": KERNELS_2D, "template": True, }, @@ -3625,9 +3624,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "sigmoid": { "op": Op.SIGMOID, @@ -3645,9 +3642,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "tanh": { "op": Op.TANH, @@ -3665,9 +3660,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, "compliance": { "abs_error_lower_bound": 0.5, }, @@ -3688,9 +3681,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, "compliance": {"ulp": 5}, }, # Elementwise Binary Operators @@ -3713,9 +3704,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, "compliance": {"ulp": 0.5}, }, "arithmetic_right_shift": { @@ -3937,9 +3926,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "minimum": { "op": Op.MINIMUM, @@ -3960,9 +3947,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "mul": { "op": Op.MUL, @@ -3983,9 +3968,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, "compliance": {"ulp": 0.5}, }, "pow": { @@ -4007,9 +3990,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "sub": { "op": Op.SUB, @@ -4030,9 +4011,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, "compliance": {"ulp": 0.5}, }, "table": { @@ -4072,9 +4051,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.FULL_RANGE,), - }, + "data_gen": EW_UNARY_DATAGEN, }, "bitwise_not": { "op": Op.BITWISE_NOT, @@ -4109,9 +4086,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.FULL_RANGE,), - }, + "data_gen": EW_UNARY_DATAGEN, "compliance": {"ulp": 0.5}, }, "clz": { @@ -4147,9 +4122,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, "compliance": {"abs_error_normal_divisor": 2}, }, "exp": { @@ -4168,9 +4141,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.FULL_RANGE,), - }, + "data_gen": EW_UNARY_DATAGEN, }, "floor": { "op": Op.FLOOR, @@ -4188,9 +4159,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.FULL_RANGE,), - }, + "data_gen": EW_UNARY_DATAGEN, "compliance": {"ulp": 0.5}, }, "log": { @@ -4209,9 +4178,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.FULL_RANGE,), - }, + "data_gen": EW_UNARY_DATAGEN, "compliance": {"ulp": 5}, }, "logical_not": { @@ -4250,9 +4217,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.FULL_RANGE,), - }, + "data_gen": EW_UNARY_DATAGEN, }, "reciprocal": { "op": Op.RECIPROCAL, @@ -4270,9 +4235,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.FULL_RANGE,), - }, + "data_gen": EW_UNARY_DATAGEN, "compliance": {"ulp": 1.0}, }, "rsqrt": { @@ -4291,9 +4254,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.FULL_RANGE,), - }, + "data_gen": EW_UNARY_DATAGEN, "compliance": {"ulp": 2}, }, "sin": { @@ -4312,9 +4273,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, "compliance": {"abs_error_normal_divisor": 2}, }, # Elementwise Ternary operators @@ -4337,9 +4296,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, # Comparison operators "equal": { @@ -4361,9 +4318,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "greater_equal": { "op": Op.GREATER_EQUAL, @@ -4384,9 +4339,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "greater": { "op": Op.GREATER, @@ -4407,9 +4360,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, # Reduction operators "reduce_all": { @@ -4474,9 +4425,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "reduce_min": { "op": Op.REDUCE_MIN, @@ -4498,9 +4447,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "reduce_product": { "op": Op.REDUCE_PRODUCT, @@ -4522,9 +4469,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "reduce_sum": { "op": Op.REDUCE_SUM, @@ -4546,9 +4491,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.DOT_PRODUCT,), - }, + "data_gen": DOT_PRODUCT_DATAGEN, }, # Data layout operators "concat": { @@ -4571,9 +4514,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "pad": { "op": Op.PAD, @@ -4595,9 +4536,7 @@ class TosaTestGen: TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongRank, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "dim": { "op": Op.DIM, @@ -4635,9 +4574,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "reverse": { "op": Op.REVERSE, @@ -4657,9 +4594,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "slice": { "op": Op.SLICE, @@ -4689,9 +4624,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evRankMismatch, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "tile": { "op": Op.TILE, @@ -4712,9 +4645,7 @@ class TosaTestGen: TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongRank, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "transpose": { "op": Op.TRANSPOSE, @@ -4738,9 +4669,7 @@ class TosaTestGen: TosaErrorValidator.evRankMismatch, TosaErrorValidator.evTensorSizeInputOutputMismatch, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, # Data nodes "const": { @@ -4753,9 +4682,7 @@ class TosaTestGen: TosaArgGen.agNone, ), "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2], - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "identity": { "op": Op.IDENTITY, @@ -4767,9 +4694,7 @@ class TosaTestGen: TosaArgGen.agNone, ), "types": TYPE_FIB + [DType.INT4, DType.INT48], - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, # Scatter/Gather "gather": { @@ -4799,9 +4724,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evWrongRank, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, "scatter": { "op": Op.SCATTER, @@ -4821,9 +4744,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evWrongRank, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, }, # Image operations "resize": { @@ -4859,9 +4780,7 @@ class TosaTestGen: TosaErrorValidator.evResizeOutputShapeMismatch, TosaErrorValidator.evResizeOutputShapeNonInteger, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, "compliance": {"relative": 0.006}, }, # Type conversion @@ -4891,9 +4810,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), - "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), - }, + "data_gen": PSEUDO_RANDOM_DATAGEN, "compliance": {"ulp": 0.5}, }, "rescale": { @@ -5010,9 +4927,7 @@ class TosaTestGen: TosaErrorValidator.evFFTInputShapeMismatch, TosaErrorValidator.evFFTOutputShapeMismatch, ), - "data_gen": { - "fp": (gtu.DataGenType.DOT_PRODUCT,), - }, + "data_gen": DOT_PRODUCT_DATAGEN, }, "rfft2d": { "op": Op.RFFT2D, @@ -5035,9 +4950,7 @@ class TosaTestGen: TosaErrorValidator.evKernelNotPowerOfTwo, TosaErrorValidator.evFFTOutputShapeMismatch, ), - "data_gen": { - "fp": (gtu.DataGenType.DOT_PRODUCT,), - }, + "data_gen": DOT_PRODUCT_DATAGEN, }, # Shape "add_shape": { -- cgit v1.2.1