aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py83
1 files changed, 18 insertions, 65 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 50811ac..1e23822 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -2673,97 +2673,50 @@ class TosaArgGen:
arg_list = []
origShape = shapeList[0]
-
- totalElements = 1
- for s in origShape:
- totalElements *= s
-
- # This code is NOT fast. Fortunately, the numbers are fairly small.
+ totalElements = gtu.product(origShape)
factors = TosaArgGen.getFactors(totalElements)
+ # Find new shapes up to the number of permutations asked for
+ # This code is NOT fast. Fortunately, the numbers are fairly small.
for p in range(testGen.args.num_rand_permutations):
# Rank from 1 to TOSA_TENSOR_MAX_RANK
newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
if len(factors) < newRank:
continue
- found = True
- # escape_counter breaks while loop if it continues on for too long
- escape_counter = 0
- while found:
+ # escape_counter limits the generation of new shapes to a reasonable time
+ for escape_counter in range(100):
+
+ # Generate the new shape of the chosen new rank
newShape = []
- new_shape_inferred = []
- # Generate newShape ensuring it isn't a duplicate
remainingElements = totalElements
shuffledFactors = testGen.rng.permutation(factors)
- inferred_dim = testGen.rng.integers(1, newRank + 1)
for i in range(1, newRank):
# pick rank-1 factors
newShape.append(shuffledFactors[0])
remainingElements = remainingElements // shuffledFactors[0]
- if i == inferred_dim:
- new_shape_inferred.append(-1)
- else:
- new_shape_inferred.append(shuffledFactors[0])
shuffledFactors = testGen.rng.permutation(
TosaArgGen.getFactors(remainingElements)
)
newShape.append(remainingElements)
- if inferred_dim == newRank:
- new_shape_inferred.append(-1)
- else:
- new_shape_inferred.append(remainingElements)
# Check for duplicates
- found = False
+ duplicate = False
for name, args_dict in arg_list:
if args_dict["new_shape"] == newShape:
- found = True
+ duplicate = True
break
- escape_counter += 1
- if escape_counter >= 100:
- break
-
- if not found:
- if error_name in [
- ErrorIf.ReshapeOutputSizeNonInteger,
- ErrorIf.ReshapeOutputSizeMultiInference,
- ]:
- if newRank < 2:
- # Need at least two dimensions
- continue
- # NOTE: Change inferred_dim starting offset from 1 to 0
- inferred_dim -= 1
- extra_dim = inferred_dim + testGen.rng.integers(1, newRank)
- extra_dim = extra_dim % newRank
- assert extra_dim != inferred_dim
- if error_name == ErrorIf.ReshapeOutputSizeNonInteger:
- elements = 1
- for i, dim_value in enumerate(new_shape_inferred):
- if i != inferred_dim and i != extra_dim:
- elements *= dim_value
- dim_value = new_shape_inferred[extra_dim]
- while totalElements % (elements * dim_value) == 0:
- dim_value += 1
- new_shape_inferred[extra_dim] = dim_value
- else:
- assert error_name == ErrorIf.ReshapeOutputSizeMultiInference
- new_shape_inferred[extra_dim] = -1
- else:
- arg_list.append(
- (
- "perm{}_rank{}_outdefined".format(p, newRank),
- {"new_shape": newShape},
- )
- )
- if error_name != ErrorIf.TensorSizeInputOutputMismatch:
- arg_list.append(
- (
- "perm{}_rank{}_outinferred".format(p, newRank),
- {"new_shape": new_shape_inferred},
- )
+ if not duplicate:
+ outShape = "x".join([str(x) for x in newShape])
+ arg_list.append(
+ (
+ "perm{}_rank{}_out{}".format(p, newRank, outShape),
+ {"new_shape": newShape},
)
+ )
+ # Found an output shape for this permutation
+ break
# Now add data generator types
arg_list = TosaArgGen._add_data_generators(