From 135c95544fda260e8ce622cff7835b886a97663f Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Tue, 23 May 2023 20:59:32 +0000 Subject: Add ERROR_IF to incorrect broadcast shapes Signed-off-by: Jerry Ge Change-Id: I7460ad9eed3ed5c7cec6e855a0303753ed28eb1c --- verif/generator/tosa_error_if.py | 72 ++++++++++++++++++++++++++++++++++------ 1 file changed, 61 insertions(+), 11 deletions(-) (limited to 'verif/generator/tosa_error_if.py') diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index a0a9203..d490cf2 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -83,6 +83,7 @@ class ErrorIf(object): FFTOutputShapeMismatch = "FFTOutputShapeMismatch" ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference" ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger" + BroadcastShapesMismatch = "BroadcastShapesMismatch" class TosaErrorIfArgGen: @@ -1109,17 +1110,19 @@ class TosaErrorValidator: kwargs["input3"].shape if "input3" in kwargs else input2_shape ) - for output in kwargs["result_tensors"]: - output_shape = output.shape - for i in range( - min(len(input1_shape), len(input2_shape), len(input3_shape)) - ): - if ( - (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) - or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) - or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i]) - ): - error_result = True + if len(input1_shape) == len(input2_shape) == len(input3_shape): + calculated_shape = TosaErrorValidator.calculateBroadcastShape( + input3_shape, + TosaErrorValidator.calculateBroadcastShape( + input1_shape, input2_shape + ), + ) + if calculated_shape is not None: + # Valid inputs - check for output mismatch + for output in kwargs["result_tensors"]: + output_shape = output.shape + if calculated_shape != output_shape: + error_result = True info_dict = { "error_name": error_name, @@ -2566,6 +2569,53 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def calculateBroadcastShape(input_shape_a, input_shape_b): + if input_shape_a is not None and input_shape_b is not None: + calculated_shape = input_shape_a.copy() + for idx in range(len(calculated_shape)): + if calculated_shape[idx] == 1: + calculated_shape[idx] = input_shape_b[idx] + elif ( + input_shape_b[idx] != 1 + and input_shape_b[idx] != calculated_shape[idx] + ): + return None + return calculated_shape + else: + return None + + @staticmethod + def evBroadcastShapesMismatch(check=False, **kwargs): + error_name = ErrorIf.BroadcastShapesMismatch + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Broadcast shape calculating failed" + + if check: + input_shape_a = kwargs["input1"].shape + input_shape_b = kwargs["input2"].shape + input_shape_c = ( + kwargs["input3"].shape if "input3" in kwargs else input_shape_b + ) + + if len(input_shape_a) == len(input_shape_b) == len(input_shape_c): + calculated_shape = TosaErrorValidator.calculateBroadcastShape( + input_shape_c, + TosaErrorValidator.calculateBroadcastShape( + input_shape_a, input_shape_b + ), + ) + error_result = calculated_shape is None + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + class TosaInvalidValidator: @staticmethod -- cgit v1.2.1