diff options
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r-- | verif/generator/tosa_error_if.py | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index ee227b3..b19d5e9 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -1067,7 +1067,9 @@ class TosaErrorValidator: if check: input1_shape = kwargs["input1"].shape - input2_shape = kwargs["input2"].shape + input2_shape = ( + kwargs["input2"].shape if "input2" in kwargs else input1_shape + ) # In case of SELECT op input3_shape = ( kwargs["input3"].shape if "input3" in kwargs else input2_shape @@ -1921,11 +1923,13 @@ class TosaErrorValidator: input_shape = kwargs["input_shape"] output_shape = kwargs["output_shape"] size = kwargs["size"] - rank = len(input_shape) - if len(size) == rank: - for index in range(rank): - if size[index] != output_shape[index]: - error_result = True + + if len(input_shape) == len(output_shape): + rank = len(input_shape) + if len(size) == rank: + for index in range(rank): + if size[index] != output_shape[index]: + error_result = True info_dict = { "error_name": error_name, |