diff options
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r-- | verif/generator/tosa_utils.py | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 29ae898..8ff62f1 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -148,6 +148,17 @@ def get_wrong_output_type(op_name, rng, input_dtype): return rng.choice(a=incorrect_types) +def get_rank_mismatch_shape(rng, output_shape): + """ + Extends the rank of the provided output_shape by + an arbitrary amount but ensures the total element + count remains the same. + """ + rank_modifier = rng.choice([1, 2, 3]) + output_shape += [1] * rank_modifier + return output_shape + + def float32_is_valid_bfloat16(f): """Return True if float value is valid bfloat16.""" f32_bits = get_float32_bitstring(f) |