diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 79 |
1 files changed, 62 insertions, 17 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index a655a50..f598377 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -622,6 +622,28 @@ class TosaTensorGen: return new_shapeList + @staticmethod + def tgShape(testGen, opName, rank, error_name=None): + pl, const = opName["operands"] + shape = [rank] + + # Constrict the overall size of the shape when creating ERROR_IF tests + if error_name: + shape = TosaErrorIfArgGen.eiRestrictDimensions(shape) + + shape_list = [] + for i in range(pl + const): + shape_list.append(shape.copy()) + + # Generates an input rank mismatch for operators with more than one input + if error_name == ErrorIf.RankMismatch: + if rank == 1 and i != 1: + shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3])) + elif i != 1: + shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1])) + + return shape_list + class TosaTensorValuesGen: """Tensor Value generators create the random data for each tensor in each test.""" @@ -891,7 +913,7 @@ class TosaTensorValuesGen: @staticmethod def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None): - if dtypeList[0] == DType.INT32 and error_name is None: + if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None: # Make sure the integer operation does not cause value saturation - where # the number wraps due to limited number of bits to store the answer op = testGen.TOSA_OP_LIST[opName] @@ -900,9 +922,10 @@ class TosaTensorValuesGen: pCount == 2 and cCount == 0 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts" tens_ser_list = [] - add = op["op"] == Op.ADD - a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0]) - b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1]) + add = op["op"] in (Op.ADD, Op.ADD_SHAPE) + data_range = testGen.args.tensor_shape_range + a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0], data_range) + b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1], data_range) if add: res_arr = np.add(a_arr, b_arr, dtype=np.int64) else: @@ -1138,12 +1161,15 @@ class TosaTensorValuesGen: tens_ser_list = [] # Make sure multiply result in int32 range - shift = argsDict["shift"] + if dtypeList[0] == DType.SHAPE: + shift = 0 + else: + shift = argsDict["shift"] if dtypeList[0] == DType.INT8: num_bits = 8 elif dtypeList[0] == DType.INT16: num_bits = 16 - elif dtypeList[0] == DType.INT32: + elif dtypeList[0] in (DType.INT32, DType.SHAPE): num_bits = 32 elif error_name == ErrorIf.WrongInputType: num_bits = 8 @@ -1151,8 +1177,12 @@ class TosaTensorValuesGen: raise Exception("OpMul: invalid input dtype") for idx, shape in enumerate(shapeList[:]): - low = -(2 ** (num_bits - 1)) - high = (2 ** (num_bits - 1)) - 1 + if dtypeList[idx] == DType.SHAPE: + low = testGen.args.tensor_shape_range[0] + high = testGen.args.tensor_shape_range[1] + else: + low = -(2 ** (num_bits - 1)) + high = (2 ** (num_bits - 1)) - 1 a_arr = np.int32( testGen.rng.integers(low=low, high=high, size=shapeList[0]) @@ -1182,12 +1212,20 @@ class TosaTensorValuesGen: a_arr = a_arr // 2 b_arr = b_arr // 2 - tens_ser_list.append( - testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr) - ) - tens_ser_list.append( - testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr) - ) + if dtypeList[0] == DType.SHAPE: + tens_ser_list.append( + testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64) + ) + tens_ser_list.append( + testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64) + ) + else: + tens_ser_list.append( + testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr) + ) + tens_ser_list.append( + testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr) + ) return TosaTensorValuesGen.TVGInfo(tens_ser_list, None) @@ -1199,9 +1237,16 @@ class TosaTensorValuesGen: if testGen.args.num_const_inputs_concat == 0: count = len(shapeList) - shapeList = TosaTensorGen.tgConcatConstInput( - testGen, shapeList, argsDict["axis"], error_name - ) + op = testGen.TOSA_OP_LIST[opName] + if op["op"] == Op.CONCAT_SHAPE: + # Set the axis to 0 + shapeList = TosaTensorGen.tgConcatConstInput( + testGen, shapeList, 0, error_name + ) + else: + shapeList = TosaTensorGen.tgConcatConstInput( + testGen, shapeList, argsDict["axis"], error_name + ) # Override default pCount/cCount for operator argsDict["p_count"] = count |