diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 351 |
1 files changed, 229 insertions, 122 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 3014c81..d15f785 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -1,8 +1,12 @@ # Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 +import json import os from copy import deepcopy +from datetime import datetime +from pathlib import Path +import generator.tosa_utils as gtu import numpy as np import serializer.tosa_serializer as ts from generator.tosa_arg_gen import TosaArgGen @@ -13,15 +17,15 @@ from generator.tosa_error_if import ErrorIf from generator.tosa_error_if import TosaErrorIfArgGen from generator.tosa_error_if import TosaErrorValidator from generator.tosa_error_if import TosaInvalidValidator -from generator.tosa_utils import DTYPE_ATTRIBUTES -from generator.tosa_utils import get_rank_mismatch_shape -from generator.tosa_utils import get_wrong_output_type -from generator.tosa_utils import MAX_RESIZE_DIMENSION -from generator.tosa_utils import usableDTypes -from generator.tosa_utils import vect_f32_to_bf16 +from schemavalidation.schemavalidation import TestDescSchemaValidator from tosa.DType import DType from tosa.Op import Op +TOSA_AUTOGENERATED_HEADER = f"""// Copyright (c) {datetime.today().year}, ARM Limited +// SPDX-License-Identifier: Apache-2.0 +// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests +""" + class TosaTestGen: # Maximum rank of tensor supported by test generator. @@ -31,6 +35,10 @@ class TosaTestGen: TOSA_8K_LEVEL_MAX_KERNEL = 8192 TOSA_8K_LEVEL_MAX_STRIDE = 8192 + # Main compliance dot product statistical test range + TOSA_MI_DOT_PRODUCT_TEST_SETS = range(0, 6) + TOSA_MI_DOT_PRODUCT_MIN = 1000 + def __init__(self, args): self.args = args self.basePath = args.output_dir @@ -45,6 +53,8 @@ class TosaTestGen: # Work out floating point range self.random_fp_low = min(args.tensor_fp_value_range) self.random_fp_high = max(args.tensor_fp_value_range) + # JSON schema validation + self.descSchemaValidator = TestDescSchemaValidator() def createSerializer(self, opName, testPath): self.testPath = os.path.join(opName, testPath) @@ -53,81 +63,131 @@ class TosaTestGen: os.makedirs(fullPath, exist_ok=True) # Embed const data in the flatbuffer constMode = ts.ConstMode.EMBED - if self.args.dump_consts: + if self.args.lazy_data_gen: + # Lazy data generation - so make constants files + constMode = ts.ConstMode.INPUTS + elif self.args.dump_consts: constMode = ts.ConstMode.EMBED_DUMP self.ser = ts.TosaSerializer(fullPath, constMode) def getSerializer(self): return self.ser - def serialize(self, testName): - with open( - os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb" - ) as fd: + def serialize(self, testName, metaData=None): + path = Path(self.basePath) / self.testPath + + # Write out TOSA flatbuffer binary + path_fb = path / f"{testName}.tosa" + with path_fb.open("wb") as fd: fd.write(self.ser.serialize()) - with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd: - fd.write(self.ser.writeJson("{}.tosa".format(testName))) + # Get JSON descriptor from serializer + desc = json.loads(self.ser.writeJson(f"{testName}.tosa")) + + if metaData: + # Add extra meta data to desc.json + desc["meta"] = metaData + + # Validate desc.json before we output it + self.descSchemaValidator.validate_config(desc) + + if metaData: + if self.args.lazy_data_gen and "data_gen" in metaData: + # Output datagen meta data as CPP data + path_md = path / f"{testName}_meta_data_gen.cpp" + with path_md.open("w") as fd: + fd.write(TOSA_AUTOGENERATED_HEADER) + fd.write("// Test meta data for data generation setup\n\n") + fd.write(f'const char* json_tdg_config_{path.stem} = R"(') + json.dump(metaData["data_gen"], fd) + fd.write(')";\n\n') + if "compliance" in metaData: + # Output datagen meta data as CPP data + path_md = path / f"{testName}_meta_compliance.cpp" + with path_md.open("w") as fd: + fd.write(TOSA_AUTOGENERATED_HEADER) + fd.write("// Test meta data for compliance validation\n\n") + fd.write(f'const char* json_tvf_config_{path.stem} = R"(') + json.dump(metaData["compliance"], fd) + fd.write(')";\n\n') + + # Write desc.json + path_desc = path / "desc.json" + with path_desc.open("w") as fd: + json.dump(desc, fd, indent=1) def resetRNG(self, seed=None): if seed is None: seed = self.random_seed + 1 self.rng = np.random.default_rng(seed) - def getRandTensor(self, shape, dtype): - if dtype == DType.BOOL: - return np.bool_(self.rng.choice(a=[False, True], size=shape)) - # TOSA specific INT4 weight range from -7 to 7 + def getDTypeRange(self, dtype, high_inclusive=False): + # Returns dtype value range boundaries (low, high) + # The high boundary is excluded in the range + # unless high_inclusive is True + + if dtype in (DType.FP32, DType.FP16, DType.BF16): + return (self.random_fp_low, self.random_fp_high) + elif dtype == DType.BOOL: + rng = (0, 2) + elif dtype == DType.UINT8: + rng = (0, 256) + elif dtype == DType.UINT16: + rng = (0, 65536) elif dtype == DType.INT4: - return np.int32(self.rng.integers(low=-7, high=8, size=shape)) + # TOSA specific INT4 weight range from -7 to 7 + rng = (-7, 8) elif dtype == DType.INT8: - return np.int32(self.rng.integers(low=-128, high=128, size=shape)) - elif dtype == DType.UINT8: - return np.int32(self.rng.integers(low=0, high=256, size=shape)) + rng = (-128, 128) elif dtype == DType.INT16: - return np.int32(self.rng.integers(low=-32768, high=32768, size=shape)) - elif dtype == DType.UINT16: - return np.int32(self.rng.integers(low=0, high=65536, size=shape)) - elif ( - dtype == DType.INT32 or dtype == DType.SHAPE - ): # restricting too large value for SHAPE - return np.int32( - self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape) - ) + rng = (-32768, 32768) + elif dtype in (DType.INT32, DType.SHAPE): + # restricting too large value for SHAPE + rng = (-(1 << 31), (1 << 31)) elif dtype == DType.INT48: - return np.int64( - self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape) - ) - elif dtype == DType.FP16: - return np.float16( - self.rng.uniform( - low=self.random_fp_low, high=self.random_fp_high, size=shape - ) - ) - elif dtype == DType.BF16: - f32_tensor = np.float32( - self.rng.uniform( - low=self.random_fp_low, high=self.random_fp_high, size=shape - ) - ) - # Floor the last 16 bits of each f32 value - return np.float32(vect_f32_to_bf16(f32_tensor)) - elif dtype == DType.FP32: - return np.float32( - self.rng.uniform( - low=self.random_fp_low, high=self.random_fp_high, size=shape - ) - ) + rng = (-(1 << 47), (1 << 47)) + else: + raise Exception("Unknown dtype: {}".format(dtype)) + + if not high_inclusive: + # Exclusive high: low <= range < high + return rng else: - raise Exception("Unrecognized Dtype: {}".format(dtype)) + # Inclusive range: low <= range <= high + return (rng[0], rng[1] - 1) + + def getRandTensor(self, shape, dtype): + low, high = self.getDTypeRange(dtype) + + if dtype == DType.BOOL: + return np.bool_(self.rng.choice(a=[False, True], size=shape)) + elif dtype == DType.INT48: + return np.int64(self.rng.integers(low=low, high=high, size=shape)) + elif dtype in (DType.FP16, DType.BF16, DType.FP32): + f_tensor = self.rng.uniform(low=low, high=high, size=shape) + + if dtype == DType.FP16: + return np.float16(f_tensor) + else: + f32_tensor = np.float32(f_tensor) + if dtype == DType.BF16: + # Floor the last 16 bits of each f32 value + return np.float32(gtu.vect_f32_to_bf16(f32_tensor)) + else: + return f32_tensor + else: + # All other integer types + return np.int32(self.rng.integers(low=low, high=high, size=shape)) def buildPlaceholderTensors(self, shape_list, dtype_list): placeholders = [] assert len(shape_list) == len(dtype_list) + arr = None for idx, shape in enumerate(shape_list): - arr = self.getRandTensor(shape, dtype_list[idx]) + if not self.args.lazy_data_gen: + arr = self.getRandTensor(shape, dtype_list[idx]) placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr)) return placeholders @@ -137,8 +197,10 @@ class TosaTestGen: assert len(shape_list) == len(dtype_list) + arr = None for idx, shape in enumerate(shape_list): - arr = self.getRandTensor(shape, dtype_list[idx]) + if not self.args.lazy_data_gen: + arr = self.getRandTensor(shape, dtype_list[idx]) consts.append(self.ser.addConst(shape, dtype_list[idx], arr)) return consts @@ -161,38 +223,20 @@ class TosaTestGen: return np.int32(self.rng.integers(low=low, high=high, size=1))[0] def getRandNumberDType(self, dtype): + low, high = self.getDTypeRange(dtype) + if dtype == DType.FP32: - return np.float32( - self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high) - ) + return np.float32(self.rng.uniform(low=low, high=high)) elif dtype == DType.FP16: - return np.float16( - self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high) - ) + return np.float16(self.rng.uniform(low=low, high=high)) elif dtype == DType.BF16: - rand_f32 = np.float32( - self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high) - ) - return vect_f32_to_bf16(rand_f32) + rand_f32 = np.float32(self.rng.uniform(low=low, high=high)) + return gtu.vect_f32_to_bf16(rand_f32) elif dtype == DType.BOOL: return self.rng.choice([False, True]) - # TOSA specific INT4 weight range from -7 to 7 - elif dtype == DType.INT4: - low, high = (-7, 8) - elif dtype == DType.INT8: - low, high = (-128, 128) - elif dtype == DType.INT16: - low, high = (-32768, 32768) - elif ( - dtype == DType.INT32 or dtype == DType.SHAPE - ): # restricting too large value for SHAPE - low, high = (-(1 << 31), (1 << 31)) elif dtype == DType.INT48: - low, high = (-(1 << 47), (1 << 47)) # Special size return np.int64(self.rng.integers(low, high, size=1))[0] - else: - raise Exception("Unknown dtype: {}".format(dtype)) return np.int32(self.rng.integers(low, high, size=1))[0] @@ -212,8 +256,8 @@ class TosaTestGen: # Limit types to the first 2 as the 3rd is the accumulator return "x".join(strs[:2]) else: - if dtype in DTYPE_ATTRIBUTES: - return DTYPE_ATTRIBUTES[dtype]["str"] + if dtype in gtu.DTYPE_ATTRIBUTES: + return gtu.DTYPE_ATTRIBUTES[dtype]["str"] else: raise Exception( "Unknown dtype, cannot convert to string: {}".format(dtype) @@ -221,8 +265,8 @@ class TosaTestGen: def typeWidth(self, dtype): """Get the datatype width for data types""" - if dtype in DTYPE_ATTRIBUTES: - return DTYPE_ATTRIBUTES[dtype]["width"] + if dtype in gtu.DTYPE_ATTRIBUTES: + return gtu.DTYPE_ATTRIBUTES[dtype]["width"] else: raise Exception(f"Unknown dtype, cannot determine width: {dtype}") @@ -237,11 +281,44 @@ class TosaTestGen: low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1] ) - # Argument generators - # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list]) - # Where the string descriptor is used to generate the test name and - # The build_fcn_arg_list is expanded and passed to the operator test - # build function + def tensorComplianceMetaData(self, op, argsDict, outputTensor, errorName): + if errorName: + # No compliance for error tests + return None + # Create compliance meta data for expected output tensor + compliance_tens = {"mode": None} + if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT: + mode = gtu.ComplianceMode.DOT_PRODUCT + compliance_tens["dot_product_info"] = { + "s": argsDict["s"], + "ks": argsDict["ks"], + "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"], + } + elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL: + mode = gtu.ComplianceMode.FP_SPECIAL + elif "compliance" in op and "ulp" in op["compliance"]: + mode = gtu.ComplianceMode.ULP + compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]} + elif op["op"] == Op.REDUCE_PRODUCT: + mode = gtu.ComplianceMode.REDUCE_PRODUCT + else: + mode = gtu.ComplianceMode.EXACT + compliance_tens["mode"] = gtu.ComplianceMode(mode).name + + return compliance_tens + + # Build Op functions + # Create the output tensor (calling OutputShaper as needed) + # Do final tweaks to attributes (if necessary for errorIf) + # Add Op into graph + # Return resulting tensor information or BuildInfo + + class BuildInfo: + """Enhanced build information containing result tensor and associated compliance dict.""" + + def __init__(self, resultTensor, complianceDict): + self.resultTensor = resultTensor + self.complianceDict = complianceDict def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None): result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) @@ -975,15 +1052,16 @@ class TosaTestGen: return result_tens def build_matmul( - self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None + self, op, a, b, args_dict, validator_fcns=None, error_name=None, qinfo=None ): - result_tens = OutputShaper.matmulOp( + accum_dtype = args_dict["acc_type"] + result_tensor = OutputShaper.matmulOp( self.ser, self.rng, a, b, accum_dtype, error_name ) # Invalidate Input/Output list for error if checks. input_list = [a.name, b.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -999,10 +1077,10 @@ class TosaTestGen: input_dtype=a.dtype, input2_shape=b.shape, input2_dtype=b.dtype, - output_shape=result_tens.shape, - output_dtype=result_tens.dtype, + output_shape=result_tensor.shape, + output_dtype=result_tensor.dtype, qinfo=qinfo, - result_tensors=[result_tens], + result_tensors=[result_tensor], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1014,7 +1092,12 @@ class TosaTestGen: attr.MatMulAttribute(qinfo[0], qinfo[1]) self.ser.addOperator(op["op"], input_list, output_list, attr) - return result_tens + + compliance = self.tensorComplianceMetaData( + op, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_reduce(self, op, a, axis, validator_fcns, error_name=None): result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name) @@ -1895,7 +1978,7 @@ class TosaTestGen: def _get_condition_tensor(self, op, cond, error_name): if error_name == ErrorIf.CondIfCondNotMatchingBool: - cond_type = get_wrong_output_type(op, self.rng, DType.BOOL) + cond_type = gtu.get_wrong_output_type(op, self.rng, DType.BOOL) else: cond_type = DType.BOOL if error_name == ErrorIf.CondIfCondShapeNotSizeOne: @@ -2357,7 +2440,7 @@ class TosaTestGen: # Initialize a new random number generator self.rng = np.random.default_rng(self.random_seed) - build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"] + _, tgen_fcn, _, agen_fcn = op["build_fcn"] # Test list consists of a tuple of: # (opName, testNameStr, dtype, shapeList, argumentsList) @@ -2461,7 +2544,7 @@ class TosaTestGen: # Create a serializer self.createSerializer(opName, testStr) - build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"] + build_fcn, _, tvgen_fcn, _ = op["build_fcn"] if "error_if_validators" in op: error_if_validators = op["error_if_validators"] else: @@ -2495,24 +2578,37 @@ class TosaTestGen: qgen = None # Build the random tensor operands and the test - tens = [] if qgen is not None: qinfo = qgen(self, op, dtype_or_dtypeList, error_name) else: qinfo = None - tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name) + # Extra meta data for the desc.json + tensMeta = {} + + # Check we are using the new testArgs interface with an argsDict dictionary + if len(testArgs) == 1 and isinstance(testArgs[0], dict): + argsDict = testArgs[0] + assert "dg_type" in argsDict + tvgInfo = tvgen_fcn( + self, opName, dtypeList, shapeList, argsDict, error_name + ) + if tvgInfo.dataGenDict: + tensMeta["data_gen"] = tvgInfo.dataGenDict + tens = tvgInfo.tensorList + else: + tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name) try: if error_if_validators is None: if qinfo is not None: - resultName = build_fcn(self, op, *tens, *testArgs, qinfo) + result = build_fcn(self, op, *tens, *testArgs, qinfo) else: - resultName = build_fcn(self, op, *tens, *testArgs) + result = build_fcn(self, op, *tens, *testArgs) else: if qinfo is not None: - resultName = build_fcn( + result = build_fcn( self, op, *tens, @@ -2522,7 +2618,7 @@ class TosaTestGen: qinfo=qinfo, ) else: - resultName = build_fcn( + result = build_fcn( self, op, *tens, @@ -2534,9 +2630,16 @@ class TosaTestGen: print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n") raise e - if resultName: + if result: # The test is valid, serialize it - self.serialize("test") + if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict: + # Add the compliance meta data + # NOTE: This currently expects only one result output + tensMeta["compliance"] = { + "version": "0.1", + "tensors": {result.resultTensor.name: result.complianceDict}, + } + self.serialize("test", tensMeta) else: # The test is not valid print(f"Invalid ERROR_IF test created: {opName} {testStr}") @@ -2865,7 +2968,7 @@ class TosaTestGen: "build_fcn": ( build_matmul, TosaTensorGen.tgMatmul, - TosaTensorValuesGen.tvgDefault, + TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agMatMul, ), "qgen": TosaQuantGen.qgMatmul, @@ -2878,6 +2981,10 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.DOT_PRODUCT,), + "int": (gtu.DataGenType.PSEUDO_RANDOM,), + }, }, "max_pool2d": { "op": Op.MAX_POOL2D, @@ -4446,7 +4553,7 @@ class OutputShaper: excludes = [DType.FP16, DType.FP32] else: excludes = [out_dtype] - wrong_dtypes = list(usableDTypes(excludes=excludes)) + wrong_dtypes = list(gtu.usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(ofm_shape, out_dtype) @@ -4508,7 +4615,7 @@ class OutputShaper: excludes = [DType.FP16, DType.FP32] else: excludes = [out_dtype] - wrong_dtypes = list(usableDTypes(excludes=excludes)) + wrong_dtypes = list(gtu.usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(ofm_shape, out_dtype) @@ -4559,7 +4666,7 @@ class OutputShaper: excludes = [DType.FP16, DType.FP32] else: excludes = [out_dtype] - wrong_dtypes = list(usableDTypes(excludes=excludes)) + wrong_dtypes = list(gtu.usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(ofm_shape, out_dtype) @@ -4711,7 +4818,7 @@ class OutputShaper: bad_dim = rng.choice(range(len(output_shape))) output_shape[bad_dim] -= rng.choice([1, 2]) elif error_name == ErrorIf.RankMismatch: - output_shape = get_rank_mismatch_shape(rng, output_shape) + output_shape = gtu.get_rank_mismatch_shape(rng, output_shape) if error_name == ErrorIf.WrongOutputType: all_dtypes = [ @@ -4806,7 +4913,7 @@ class OutputShaper: elif error_name == ErrorIf.InputSizeStartLengthMismatch: output_shape = input.shape.copy() elif error_name == ErrorIf.RankMismatch: - output_shape = get_rank_mismatch_shape(rng, output_shape) + output_shape = gtu.get_rank_mismatch_shape(rng, output_shape) return ser.addOutput(output_shape, outputDType) @@ -4820,7 +4927,7 @@ class OutputShaper: output_shape[i] = a.shape[i] * multiples[i] if error_name == ErrorIf.RankMismatch: - output_shape = get_rank_mismatch_shape(rng, output_shape) + output_shape = gtu.get_rank_mismatch_shape(rng, output_shape) if error_name == ErrorIf.WrongOutputType: all_dtypes = [ @@ -4853,7 +4960,7 @@ class OutputShaper: for i in range(len(output_shape)): output_shape[i] += rng.integers(1, 10) elif error_name == ErrorIf.RankMismatch: - output_shape = get_rank_mismatch_shape(rng, output_shape) + output_shape = gtu.get_rank_mismatch_shape(rng, output_shape) if error_name == ErrorIf.WrongOutputType: all_dtypes = [ @@ -4980,21 +5087,21 @@ class OutputShaper: oh = max(oh, 1) ow = max(ow, 1) if error_name != ErrorIf.MaxDimExceeded: - oh = min(oh, MAX_RESIZE_DIMENSION - 1) - ow = min(ow, MAX_RESIZE_DIMENSION - 1) + oh = min(oh, gtu.MAX_RESIZE_DIMENSION - 1) + ow = min(ow, gtu.MAX_RESIZE_DIMENSION - 1) if error_name == ErrorIf.ResizeOutputShapeMismatch: choices = [1, 2, 3] change = rng.choice(choices) # increment in multiples of scale_y/x_d so we don't hit non-integer error case if change in [1, 3]: - if oh + scale_y_d >= MAX_RESIZE_DIMENSION: + if oh + scale_y_d >= gtu.MAX_RESIZE_DIMENSION: oh -= scale_y_d assert oh > 0 # Should have been caught in agResize else: oh += scale_y_d if change in [2, 3]: - if ow + scale_x_d >= MAX_RESIZE_DIMENSION: + if ow + scale_x_d >= gtu.MAX_RESIZE_DIMENSION: ow -= scale_x_d assert ow > 0 # Should have been caught in agResize else: @@ -5051,7 +5158,7 @@ class OutputShaper: excludes = [DType.FP16, DType.FP32] else: excludes = [out_dtype] - wrong_dtypes = list(usableDTypes(excludes=excludes)) + wrong_dtypes = list(gtu.usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(output_shape, out_dtype) @@ -5075,7 +5182,7 @@ class OutputShaper: if error_name == ErrorIf.WrongOutputType: excludes = [DType.FP32] - wrong_dtypes = list(usableDTypes(excludes=excludes)) + wrong_dtypes = list(gtu.usableDTypes(excludes=excludes)) output_dtype = rng.choice(wrong_dtypes) elif error_name == ErrorIf.BatchMismatch: output_shape[0] += rng.integers(1, 10) @@ -5100,7 +5207,7 @@ class OutputShaper: output_dtype = value.dtype if error_name == ErrorIf.WrongOutputType: excludes = [DType.FP32] - wrong_dtypes = list(usableDTypes(excludes=excludes)) + wrong_dtypes = list(gtu.usableDTypes(excludes=excludes)) output_dtype = rng.choice(wrong_dtypes) elif error_name == ErrorIf.BatchMismatch: output_shape[0] += rng.integers(1, 10) |