diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 83 |
1 files changed, 18 insertions, 65 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 50811ac..1e23822 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -2673,97 +2673,50 @@ class TosaArgGen: arg_list = [] origShape = shapeList[0] - - totalElements = 1 - for s in origShape: - totalElements *= s - - # This code is NOT fast. Fortunately, the numbers are fairly small. + totalElements = gtu.product(origShape) factors = TosaArgGen.getFactors(totalElements) + # Find new shapes up to the number of permutations asked for + # This code is NOT fast. Fortunately, the numbers are fairly small. for p in range(testGen.args.num_rand_permutations): # Rank from 1 to TOSA_TENSOR_MAX_RANK newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1)) if len(factors) < newRank: continue - found = True - # escape_counter breaks while loop if it continues on for too long - escape_counter = 0 - while found: + # escape_counter limits the generation of new shapes to a reasonable time + for escape_counter in range(100): + + # Generate the new shape of the chosen new rank newShape = [] - new_shape_inferred = [] - # Generate newShape ensuring it isn't a duplicate remainingElements = totalElements shuffledFactors = testGen.rng.permutation(factors) - inferred_dim = testGen.rng.integers(1, newRank + 1) for i in range(1, newRank): # pick rank-1 factors newShape.append(shuffledFactors[0]) remainingElements = remainingElements // shuffledFactors[0] - if i == inferred_dim: - new_shape_inferred.append(-1) - else: - new_shape_inferred.append(shuffledFactors[0]) shuffledFactors = testGen.rng.permutation( TosaArgGen.getFactors(remainingElements) ) newShape.append(remainingElements) - if inferred_dim == newRank: - new_shape_inferred.append(-1) - else: - new_shape_inferred.append(remainingElements) # Check for duplicates - found = False + duplicate = False for name, args_dict in arg_list: if args_dict["new_shape"] == newShape: - found = True + duplicate = True break - escape_counter += 1 - if escape_counter >= 100: - break - - if not found: - if error_name in [ - ErrorIf.ReshapeOutputSizeNonInteger, - ErrorIf.ReshapeOutputSizeMultiInference, - ]: - if newRank < 2: - # Need at least two dimensions - continue - # NOTE: Change inferred_dim starting offset from 1 to 0 - inferred_dim -= 1 - extra_dim = inferred_dim + testGen.rng.integers(1, newRank) - extra_dim = extra_dim % newRank - assert extra_dim != inferred_dim - if error_name == ErrorIf.ReshapeOutputSizeNonInteger: - elements = 1 - for i, dim_value in enumerate(new_shape_inferred): - if i != inferred_dim and i != extra_dim: - elements *= dim_value - dim_value = new_shape_inferred[extra_dim] - while totalElements % (elements * dim_value) == 0: - dim_value += 1 - new_shape_inferred[extra_dim] = dim_value - else: - assert error_name == ErrorIf.ReshapeOutputSizeMultiInference - new_shape_inferred[extra_dim] = -1 - else: - arg_list.append( - ( - "perm{}_rank{}_outdefined".format(p, newRank), - {"new_shape": newShape}, - ) - ) - if error_name != ErrorIf.TensorSizeInputOutputMismatch: - arg_list.append( - ( - "perm{}_rank{}_outinferred".format(p, newRank), - {"new_shape": new_shape_inferred}, - ) + if not duplicate: + outShape = "x".join([str(x) for x in newShape]) + arg_list.append( + ( + "perm{}_rank{}_out{}".format(p, newRank, outShape), + {"new_shape": newShape}, ) + ) + # Found an output shape for this permutation + break # Now add data generator types arg_list = TosaArgGen._add_data_generators( |