aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py16
1 files changed, 10 insertions, 6 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index ee227b3..b19d5e9 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1067,7 +1067,9 @@ class TosaErrorValidator:
if check:
input1_shape = kwargs["input1"].shape
- input2_shape = kwargs["input2"].shape
+ input2_shape = (
+ kwargs["input2"].shape if "input2" in kwargs else input1_shape
+ )
# In case of SELECT op
input3_shape = (
kwargs["input3"].shape if "input3" in kwargs else input2_shape
@@ -1921,11 +1923,13 @@ class TosaErrorValidator:
input_shape = kwargs["input_shape"]
output_shape = kwargs["output_shape"]
size = kwargs["size"]
- rank = len(input_shape)
- if len(size) == rank:
- for index in range(rank):
- if size[index] != output_shape[index]:
- error_result = True
+
+ if len(input_shape) == len(output_shape):
+ rank = len(input_shape)
+ if len(size) == rank:
+ for index in range(rank):
+ if size[index] != output_shape[index]:
+ error_result = True
info_dict = {
"error_name": error_name,