diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 46 |
1 files changed, 45 insertions, 1 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 2bbc349..9386ec2 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1878,17 +1878,27 @@ class TosaArgGen: escape_counter = 0 while found: newShape = [] + new_shape_inferred = [] # Generate newShape ensuring it isn't a duplicate remainingElements = totalElements shuffledFactors = testGen.rng.permutation(factors) + inferred_dim = testGen.rng.integers(1, newRank + 1) for i in range(1, newRank): # pick rank-1 factors newShape.append(shuffledFactors[0]) remainingElements = remainingElements // shuffledFactors[0] + if i == inferred_dim: + new_shape_inferred.append(-1) + else: + new_shape_inferred.append(shuffledFactors[0]) shuffledFactors = testGen.rng.permutation( TosaArgGen.getFactors(remainingElements) ) newShape.append(remainingElements) + if inferred_dim == newRank: + new_shape_inferred.append(-1) + else: + new_shape_inferred.append(remainingElements) # Check for duplicates found = False @@ -1902,7 +1912,41 @@ class TosaArgGen: break if not found: - arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape])) + if error_name in [ + ErrorIf.ReshapeOutputSizeNonInteger, + ErrorIf.ReshapeOutputSizeMultiInference, + ]: + if newRank < 2: + # Need at least two dimensions + continue + # NOTE: Change inferred_dim starting offset from 1 to 0 + inferred_dim -= 1 + extra_dim = inferred_dim + testGen.rng.integers(1, newRank) + extra_dim = extra_dim % newRank + assert extra_dim != inferred_dim + if error_name == ErrorIf.ReshapeOutputSizeNonInteger: + elements = 1 + for i, dim_value in enumerate(new_shape_inferred): + if i != inferred_dim and i != extra_dim: + elements *= dim_value + dim_value = new_shape_inferred[extra_dim] + while totalElements % (elements * dim_value) == 0: + dim_value += 1 + new_shape_inferred[extra_dim] = dim_value + else: + assert error_name == ErrorIf.ReshapeOutputSizeMultiInference + new_shape_inferred[extra_dim] = -1 + else: + arg_list.append( + ("perm{}_rank{}_outdefined".format(p, newRank), [newShape]) + ) + if error_name != ErrorIf.TensorSizeInputOutputMismatch: + arg_list.append( + ( + "perm{}_rank{}_outinferred".format(p, newRank), + [new_shape_inferred], + ) + ) return arg_list |