aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-05-23 20:59:32 +0000
committerDominic Symes <dominic.symes@arm.com>2023-06-15 18:25:54 +0000
commit135c95544fda260e8ce622cff7835b886a97663f (patch)
tree5d46f8f48978112abff037309a827b5844ee80de /verif/generator/tosa_error_if.py
parentcb7201e173961760c042cade591afe763c949c8f (diff)
downloadreference_model-135c95544fda260e8ce622cff7835b886a97663f.tar.gz
Add ERROR_IF to incorrect broadcast shapes
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I7460ad9eed3ed5c7cec6e855a0303753ed28eb1c
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py72
1 files changed, 61 insertions, 11 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index a0a9203..d490cf2 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -83,6 +83,7 @@ class ErrorIf(object):
FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
+ BroadcastShapesMismatch = "BroadcastShapesMismatch"
class TosaErrorIfArgGen:
@@ -1109,17 +1110,19 @@ class TosaErrorValidator:
kwargs["input3"].shape if "input3" in kwargs else input2_shape
)
- for output in kwargs["result_tensors"]:
- output_shape = output.shape
- for i in range(
- min(len(input1_shape), len(input2_shape), len(input3_shape))
- ):
- if (
- (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
- or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
- or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
- ):
- error_result = True
+ if len(input1_shape) == len(input2_shape) == len(input3_shape):
+ calculated_shape = TosaErrorValidator.calculateBroadcastShape(
+ input3_shape,
+ TosaErrorValidator.calculateBroadcastShape(
+ input1_shape, input2_shape
+ ),
+ )
+ if calculated_shape is not None:
+ # Valid inputs - check for output mismatch
+ for output in kwargs["result_tensors"]:
+ output_shape = output.shape
+ if calculated_shape != output_shape:
+ error_result = True
info_dict = {
"error_name": error_name,
@@ -2566,6 +2569,53 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def calculateBroadcastShape(input_shape_a, input_shape_b):
+ if input_shape_a is not None and input_shape_b is not None:
+ calculated_shape = input_shape_a.copy()
+ for idx in range(len(calculated_shape)):
+ if calculated_shape[idx] == 1:
+ calculated_shape[idx] = input_shape_b[idx]
+ elif (
+ input_shape_b[idx] != 1
+ and input_shape_b[idx] != calculated_shape[idx]
+ ):
+ return None
+ return calculated_shape
+ else:
+ return None
+
+ @staticmethod
+ def evBroadcastShapesMismatch(check=False, **kwargs):
+ error_name = ErrorIf.BroadcastShapesMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Broadcast shape calculating failed"
+
+ if check:
+ input_shape_a = kwargs["input1"].shape
+ input_shape_b = kwargs["input2"].shape
+ input_shape_c = (
+ kwargs["input3"].shape if "input3" in kwargs else input_shape_b
+ )
+
+ if len(input_shape_a) == len(input_shape_b) == len(input_shape_c):
+ calculated_shape = TosaErrorValidator.calculateBroadcastShape(
+ input_shape_c,
+ TosaErrorValidator.calculateBroadcastShape(
+ input_shape_a, input_shape_b
+ ),
+ )
+ error_result = calculated_shape is None
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
class TosaInvalidValidator:
@staticmethod