aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r--verif/generator/tosa_utils.py11
1 files changed, 11 insertions, 0 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 29ae898..8ff62f1 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -148,6 +148,17 @@ def get_wrong_output_type(op_name, rng, input_dtype):
return rng.choice(a=incorrect_types)
+def get_rank_mismatch_shape(rng, output_shape):
+ """
+ Extends the rank of the provided output_shape by
+ an arbitrary amount but ensures the total element
+ count remains the same.
+ """
+ rank_modifier = rng.choice([1, 2, 3])
+ output_shape += [1] * rank_modifier
+ return output_shape
+
+
def float32_is_valid_bfloat16(f):
"""Return True if float value is valid bfloat16."""
f32_bits = get_float32_bitstring(f)