diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 174 |
1 files changed, 160 insertions, 14 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 97ff237..8d96090 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -4,12 +4,10 @@ import itertools import math import warnings +import generator.tosa_utils as gtu import numpy as np from generator.tosa_error_if import ErrorIf from generator.tosa_error_if import TosaErrorIfArgGen -from generator.tosa_utils import get_accum_dtype_from_tgTypes -from generator.tosa_utils import get_wrong_output_type -from generator.tosa_utils import MAX_RESIZE_DIMENSION from serializer.tosa_serializer import DTypeNames from tosa.DType import DType from tosa.Op import Op @@ -606,11 +604,18 @@ class TosaTensorGen: class TosaTensorValuesGen: - """Tensor Value generators create the random data for each test.""" + """Tensor Value generators create the random data for each tensor in each test.""" def __init__(self): pass + class TVGInfo: + """Enhanced tensor values information including data gen dict.""" + + def __init__(self, tensorList, dataGenDict): + self.tensorList = tensorList + self.dataGenDict = dataGenDict + @staticmethod def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None): pCount, cCount = op["operands"] @@ -624,6 +629,87 @@ class TosaTensorValuesGen: return tens @staticmethod + def tvgLazyGenDefault( + testGen, opName, dtypeList, shapeList, argsDict, error_name=None + ): + # Variable inputs versus constants + pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"] + + overrideLazy = False + if not gtu.dtypeIsFloat(dtypeList[0]) and testGen.args.lazy_data_gen: + # TEMPORARY OVERRIDE for integer types + overrideLazy = True + testGen.args.lazy_data_gen = False + + # TODO - Change to generation of data using library! + # For now - we fall back to original path (or when dealing with non-floats) + if not testGen.args.lazy_data_gen: + tens_ser_list = TosaTensorValuesGen.tvgDefault( + testGen, + testGen.TOSA_OP_LIST[opName], + dtypeList, + shapeList, + [], + error_name, + ) + if overrideLazy: + # Return to lazy mode + testGen.args.lazy_data_gen = True + return TosaTensorValuesGen.TVGInfo(tens_ser_list, None) + + # Create data generator meta-data + dg_type = argsDict["dg_type"] + dg_tens_meta = {} + tens_ser_list = [] + for idx, shape in enumerate(shapeList): + + tens_meta = {} + tens_meta["generator"] = gtu.DataGenType(dg_type).name + tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"] + tens_meta["shape"] = [int(i) for i in shape] + tens_meta["input_pos"] = idx + tens_meta["op"] = opName + + if idx < pCount: + tens_meta["input_type"] = "variable" + tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], None) + else: + tens_meta["input_type"] = "constant" + tens = testGen.ser.addConst(shape, dtypeList[idx], None) + tens_ser_list.append(tens) + + if dg_type == gtu.DataGenType.PSEUDO_RANDOM: + info = {} + # TODO - generate seed for this generator based on test + info["rng_seed"] = -1 + info["range"] = [ + str(v) + for v in testGen.getDTypeRange(dtypeList[idx], high_inclusive=True) + ] + tens_meta["pseudo_random_info"] = info + elif dg_type == gtu.DataGenType.DOT_PRODUCT: + info = {} + info["s"] = argsDict["s"] + info["ks"] = argsDict["ks"] + for key in gtu.DG_DOT_PRODUCT_OPTIONAL_INFO: + if key in argsDict: + if key.endswith("_type"): + info[key] = gtu.DTYPE_ATTRIBUTES[argsDict[key]]["json"] + else: + info[key] = argsDict[key] + tens_meta["dot_product_info"] = info + else: + # TODO - other data gen type + assert False, "TODO: support other data gen types" + dg_tens_meta[tens.name] = tens_meta + + tens_data = { + "version": "0.1", + "tensors": dg_tens_meta, + } + return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data) + + @staticmethod def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if dtypeList[0] == DType.INT32 and error_name is None: pCount, cCount = op["operands"] @@ -1024,6 +1110,50 @@ class TosaArgGen: pass @staticmethod + def _add_data_generators(testGen, opName, dtype, arg_list, error_name, **kwargs): + """Add extra tests for each type of data generator for this op.""" + if error_name is None and "data_gen" in testGen.TOSA_OP_LIST[opName]: + if dtype in [DType.FP16, DType.FP32, DType.BF16]: + dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"] + else: + dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"] + else: + # Error test or No data generator types listed - assume random + dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,) + + # Expand arg list with other data generator types + new_arg_list = [] + for dg_type in dataGenTypesList: + for arg_str, arg_attrs in arg_list: + arg_dict = arg_attrs[0] + arg_dict["dg_type"] = dg_type + + if dg_type == gtu.DataGenType.PSEUDO_RANDOM: + # Default test + new_arg_list.append((arg_str, [arg_dict])) + + elif dg_type == gtu.DataGenType.DOT_PRODUCT: + # Extra tests for each dot product test set + dot_products = kwargs["dot_products"] + if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN: + print( + f"Skipping dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}" + ) + continue + arg_dict["ks"] = kwargs["ks"] + for key in gtu.DG_DOT_PRODUCT_OPTIONAL_INFO: + if key in kwargs: + arg_dict[key] = kwargs[key] + + for s in testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS: + new_arg_str = f"{arg_str}_s{s}" + new_arg_dict = arg_dict.copy() + new_arg_dict["s"] = s + new_arg_list.append((new_arg_str, [new_arg_dict])) + + return new_arg_list + + @staticmethod def agNone(testGen, opName, shapeList, dtype, error_name=None): """A trivial argument generator for operators that don't take any non-tensor arguments""" @@ -1073,7 +1203,7 @@ class TosaArgGen: # Shape: (OFM channels), (KD), KH, KW, IFM channels filter_shape = shapeList[1] - accum_dtype = get_accum_dtype_from_tgTypes(dtypes) + accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes) # Check the rank conv3d = opName.startswith("conv3d") @@ -1258,12 +1388,12 @@ class TosaArgGen: input_dtype = dtypes[0] if error_name == ErrorIf.WrongOutputType: - accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype) + accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype) elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect accum_dtype = DType.INT32 else: - accum_dtype = get_accum_dtype_from_tgTypes(dtypes) + accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes) return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])] @@ -1285,12 +1415,28 @@ class TosaArgGen: if error_name == ErrorIf.WrongOutputType: # Get incorrect output dtype for ErrorIf case - accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)] + accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)] elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect accum_dtypes = [DType.INT32] - return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes] + arg_list = [ + (f"acc{testGen.typeStr(a)}", [{"acc_type": a}]) for a in accum_dtypes + ] + + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ks=int(shapeList[0][2]), # Set KS = C, from input A (N,H,C) + # Set dot_products = N*H*W + dot_products=gtu.product( + (shapeList[0][0], shapeList[0][1], shapeList[1][2]) + ), + ) + return arg_list @staticmethod def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None): @@ -1303,7 +1449,7 @@ class TosaArgGen: ifm_shape = shapeList[0] filter_shape = shapeList[1] - accum_dtype = get_accum_dtype_from_tgTypes(dtypes) + accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes) # Must be rank 4 if error_name != ErrorIf.WrongRank: @@ -2288,9 +2434,9 @@ class TosaArgGen: if ( output_y <= 0 - or output_y >= MAX_RESIZE_DIMENSION + or output_y >= gtu.MAX_RESIZE_DIMENSION or output_x <= 0 - or output_x >= MAX_RESIZE_DIMENSION + or output_x >= gtu.MAX_RESIZE_DIMENSION ): # Output dimensions out of scope if error_name is not None and perm > 0: @@ -2301,11 +2447,11 @@ class TosaArgGen: if error_name == ErrorIf.ResizeOutputShapeMismatch and ( ( - output_y + scale_y_d >= MAX_RESIZE_DIMENSION + output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION and output_y - scale_y_d < 1 ) or ( - output_x + scale_x_d >= MAX_RESIZE_DIMENSION + output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION and output_x - scale_x_d < 1 ) ): |