diff options
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r-- | verif/generator/tosa_error_if.py | 64 |
1 files changed, 62 insertions, 2 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index 93f975d..ee227b3 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -79,6 +79,8 @@ class ErrorIf(object): CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne" CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne" KernelNotPowerOfTwo = "KernelNotPowerOfTwo" + FFTInputShapeMismatch = "FFTInputShapeMismatch" + FFTOutputShapeMismatch = "FFTOutputShapeMismatch" class TosaErrorIfArgGen: @@ -562,7 +564,7 @@ class TosaErrorValidator: ): error_result = True - elif op["op"] == Op.RFFT2D: + elif op["op"] in [Op.FFT2D, Op.RFFT2D]: if not all([ty == input_dtype for ty in output_dtype]): error_result = True @@ -686,7 +688,7 @@ class TosaErrorValidator: op = kwargs["op"] output_list = kwargs["output_list"] expected_length = 1 - if op["op"] == Op.RFFT2D: + if op["op"] in [Op.FFT2D, Op.RFFT2D]: expected_length = 2 if len(output_list) != expected_length: @@ -2446,6 +2448,64 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evFFTInputShapeMismatch(check=False, **kwargs): + error_name = ErrorIf.FFTInputShapeMismatch + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Mismatch between real and imaginary input shapes" + + if check: + input1 = kwargs["input1"] + input2 = kwargs["input2"] + + if input1.shape != input2.shape: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + + @staticmethod + def evFFTOutputShapeMismatch(check=False, **kwargs): + error_name = ErrorIf.FFTOutputShapeMismatch + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = ( + "Mismatch between provided and expected output kernel (H, W) shape" + ) + + if check: + op = kwargs["op"] + input_shape = kwargs["input_shape"] + + if len(input_shape) == 3: + output_shapes = kwargs["output_shape"] + + # Ignoring batch size (N) from input shape + expected_shape = input_shape[1:] + if op["op"] == Op.RFFT2D: + expected_shape[1] = expected_shape[1] // 2 + 1 + + # Ignoring batch size (N) from output shapes + output_shape_0 = output_shapes[0][1:] + output_shape_1 = output_shapes[1][1:] + # Ensure sure the kernel sizes (H, W) of both outputs match the expected + if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + class TosaInvalidValidator: @staticmethod |