From 587cc84c2b8c4b0d030b5e257c9a32461c0969b9 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Thu, 8 Feb 2024 11:45:44 +0000 Subject: Update test builder internal interfaces Move remaining ops from using testArgs to argsDict. All tvg/build_fcn function interfaces updated. Signed-off-by: Jeremy Johnson Change-Id: Ie886fd931bd74608bda621363100bf8bfd7385e6 --- verif/generator/tosa_arg_gen.py | 117 +++++++++++++++++++++++++++------------- 1 file changed, 79 insertions(+), 38 deletions(-) (limited to 'verif/generator/tosa_arg_gen.py') diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 33e74b5..7ec0cfe 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -636,18 +636,6 @@ class TosaTensorValuesGen: self.tensorList = tensorList self.dataGenDict = dataGenDict - @staticmethod - def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None): - pCount, cCount = op["operands"] - - tens = [] - tens.extend( - testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount]) - ) - tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:])) - - return tens - # Default high value for random numbers TVG_FLOAT_HIGH_VALUE = { DType.FP32: (1 << 128) - (1 << (127 - 23)), @@ -969,7 +957,7 @@ class TosaTensorValuesGen: @staticmethod def tvgCondIfWhileLoop( - testGen, op, dtypeList, shapeList, testArgs, error_name=None + testGen, opName, dtypeList, shapeList, argsDict, error_name=None ): if dtypeList[0] in ( DType.INT32, @@ -979,9 +967,10 @@ class TosaTensorValuesGen: # Limit input tensors with cond_if_binary or while_loop to stop # saturation of add/sub ops with int32 and keep all logical shift # values between 0 to 31 for int16 or int8 + op = testGen.TOSA_OP_LIST[opName] pCount, cCount = op["operands"] pRemain = pCount - placeholders = [] + tens_ser_list = [] for idx, shape in enumerate(shapeList[:]): if dtypeList[0] == DType.INT32: arr = testGen.getRandTensor(shapeList[idx], DType.INT16) @@ -990,32 +979,33 @@ class TosaTensorValuesGen: testGen.rng.integers(low=0, high=32, size=shapeList[idx]) ) if pRemain > 0: - placeholders.append( + tens_ser_list.append( testGen.ser.addPlaceholder(shape, dtypeList[idx], arr) ) pRemain -= 1 else: - placeholders.append( + tens_ser_list.append( testGen.ser.addConst(shape, dtypeList[idx], arr) ) - return placeholders + return TosaTensorValuesGen.TVGInfo(tens_ser_list, None) else: - return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, error_name + return TosaTensorValuesGen.tvgLazyGenDefault( + testGen, opName, dtypeList, shapeList, argsDict, error_name ) @staticmethod def tvgArithmeticRightShift( - testGen, op, dtypeList, shapeList, testArgs, error_name=None + testGen, opName, dtypeList, shapeList, argsDict, error_name=None ): + op = testGen.TOSA_OP_LIST[opName] pCount, cCount = op["operands"] # Force value of operand[1] to be within [0, num_bits] assert ( pCount == 2 and cCount == 0 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts" - placeholders = [] + tens_ser_list = [] for idx, shape in enumerate(shapeList[:]): if idx == 1: if dtypeList[idx] == DType.INT8: @@ -1030,23 +1020,23 @@ class TosaTensorValuesGen: raise Exception("OpArithmeticRightShift: invalid input dtype") else: arr = testGen.getRandTensor(shape, dtypeList[idx]) - placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)) + tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)) - return placeholders + return TosaTensorValuesGen.TVGInfo(tens_ser_list, None) @staticmethod - def tvgReshape(testGen, op, dtypeList, shapeList, argsDict, error_name=None): + def tvgReshape(testGen, opName, dtypeList, shapeList, argsDict, error_name=None): dtypeList[1] = DType.SHAPE shapeList[1] = [len(argsDict["new_shape"])] # Create a new list for the pre-generated data in argsDict["fixed_data"] argsDict["fixed_data"] = [None, argsDict["new_shape"]] return TosaTensorValuesGen.tvgLazyGenDefault( - testGen, op, dtypeList, shapeList, argsDict, error_name + testGen, opName, dtypeList, shapeList, argsDict, error_name ) @staticmethod - def tvgPad(testGen, op, dtypeList, shapeList, argsDict, error_name=None): + def tvgPad(testGen, opName, dtypeList, shapeList, argsDict, error_name=None): # argsDict["pad"] is 2D array, need to flatten it to get list of values pad_values = argsDict["pad"].flatten() dtypeList[1] = DType.SHAPE @@ -1055,11 +1045,11 @@ class TosaTensorValuesGen: argsDict["fixed_data"] = [None, pad_values] return TosaTensorValuesGen.tvgLazyGenDefault( - testGen, op, dtypeList, shapeList, argsDict, error_name + testGen, opName, dtypeList, shapeList, argsDict, error_name ) @staticmethod - def tvgSlice(testGen, op, dtypeList, shapeList, argsDict, error_name=None): + def tvgSlice(testGen, opName, dtypeList, shapeList, argsDict, error_name=None): dtypeList[1] = DType.SHAPE shapeList[1] = [len(argsDict["start"])] dtypeList[2] = DType.SHAPE @@ -1068,17 +1058,17 @@ class TosaTensorValuesGen: argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]] return TosaTensorValuesGen.tvgLazyGenDefault( - testGen, op, dtypeList, shapeList, argsDict, error_name + testGen, opName, dtypeList, shapeList, argsDict, error_name ) @staticmethod - def tvgTile(testGen, op, dtypeList, shapeList, argsDict, error_name=None): + def tvgTile(testGen, opName, dtypeList, shapeList, argsDict, error_name=None): dtypeList[1] = DType.SHAPE shapeList[1] = [len(argsDict["multiples"])] argsDict["fixed_data"] = [None, argsDict["multiples"]] return TosaTensorValuesGen.tvgLazyGenDefault( - testGen, op, dtypeList, shapeList, argsDict, error_name + testGen, opName, dtypeList, shapeList, argsDict, error_name ) @staticmethod @@ -2776,10 +2766,23 @@ class TosaArgGen: int(double_round), int(per_channel), ), - [outDtype, scale32, double_round, per_channel], + { + "output_dtype": outDtype, + "scale": scale32, + "double_round": double_round, + "per_channel": per_channel, + }, ) ) + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + inDtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) return arg_list @staticmethod @@ -2808,9 +2811,20 @@ class TosaArgGen: def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None): arg_list = [] - arg_list.append(("roundTrue", [True])) - arg_list.append(("roundFalse", [False])) + for round in (True, False): + args_dict = { + "round": round, + } + arg_list.append((f"round{round}", args_dict)) + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) return arg_list @staticmethod @@ -3414,9 +3428,18 @@ class TosaArgGen: arg_list.append( ( "", - [table], + {"table": table}, ) ) + # Now add data generator types + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) return arg_list def agCondIf(testGen, opName, shapeList, dtype, error_name=None): @@ -3426,15 +3449,33 @@ class TosaArgGen: arg_list = [] for c in [False, True]: - arg_list.append(("cond{}".format(int(c)), [c])) + arg_list.append(("cond{}".format(int(c)), {"condition": c})) + # Now add data generator types + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) return arg_list def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None): # While loop: 0 iterations, 1, more than 1 arg_list = [] - for iter in [0, 1, 4]: - arg_list.append(("iter{}".format(iter), [iter])) + for iterations in [0, 1, 4]: + arg_list.append(("iter{}".format(iterations), {"iterations": iterations})) + # Now add data generator types + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) return arg_list -- cgit v1.2.1